diff --git a/src/writer/auth.py b/src/writer/auth.py index bc768ce8..64620854 100644 --- a/src/writer/auth.py +++ b/src/writer/auth.py @@ -1,5 +1,6 @@ import asyncio import dataclasses +import logging import os.path import time from abc import ABCMeta, abstractmethod @@ -16,6 +17,8 @@ from writer.serve import WriterFastAPI from writer.ss_types import InitSessionRequestPayload +logger = logging.getLogger('writer') + # Dictionary for storing failed attempts {ip_address: timestamp} failed_attempts: Dict[str, float] = {} @@ -181,11 +184,23 @@ def register(self, callback: Optional[Callable[[Request, str, dict], None]] = None, unauthorized_action: Optional[Callable[[Request, Unauthorized], Response]] = None ): + + redirect_url = urljoin(self.host_url, self.callback_authorize) + host_url_path = urlpath(self.host_url) + callback_authorize_path = urljoin(host_url_path, self.callback_authorize) + asset_assets_path = urljoin(host_url_path, "assets") + + logger.debug(f"[auth] oidc - url redirect: {redirect_url}") + logger.debug(f"[auth] oidc - endpoint authorize: {self.url_authorize}") + logger.debug(f"[auth] oidc - endpoint token: {self.url_oauthtoken}") + logger.debug(f"[auth] oidc - path: {host_url_path}") + logger.debug(f"[auth] oidc - authorize path: {callback_authorize_path}") + logger.debug(f"[auth] oidc - asset path: {asset_assets_path}") self.authlib = OAuth2Session( client_id=self.client_id, client_secret=self.client_secret, scope=self.scope.split(" "), - redirect_uri=_urljoin(self.host_url, self.callback_authorize), + redirect_uri=redirect_url, authorization_endpoint=self.url_authorize, token_endpoint=self.url_oauthtoken, ) @@ -195,10 +210,8 @@ def register(self, @asgi_app.middleware("http") async def oidc_middleware(request: Request, call_next): session = request.cookies.get('session') - host_url_path = _urlpath(self.host_url) - full_callback_authorize = '/' + _urljoin(host_url_path, self.callback_authorize) - full_assets = '/' + _urljoin(host_url_path, '/assets') - if session is not None or request.url.path in [full_callback_authorize] or request.url.path.startswith(full_assets): + + if session is not None or request.url.path in [callback_authorize_path] or request.url.path.startswith(asset_assets_path): response: Response = await call_next(request) return response else: @@ -206,11 +219,11 @@ async def oidc_middleware(request: Request, call_next): response = RedirectResponse(url=url[0]) return response - @asgi_app.get('/' + _urlstrip(self.callback_authorize)) + @asgi_app.get('/' + urlstrip(self.callback_authorize)) async def route_callback(request: Request): self.authlib.fetch_token(url=self.url_oauthtoken, authorization_response=str(request.url)) try: - host_url_path = _urlpath(self.host_url) + host_url_path = urlpath(self.host_url) response = RedirectResponse(url=host_url_path) session_id = session_manager.generate_session_id() @@ -300,44 +313,54 @@ def Auth0(client_id: str, client_secret: str, domain: str, host_url: str) -> Oid url_oauthtoken=f"https://{domain}/oauth/token", url_userinfo=f"https://{domain}/userinfo") -def _urlpath(url: str): +def urlpath(url: str): """ - >>> _urlpath("http://localhost/app1") + >>> urlpath("http://localhost/app1") >>> "/app1" + + >>> urlpath("http://localhost") + >>> "/" """ - return urlparse(url).path + path = urlparse(url).path + if len(path) == 0: + return "/" + else: + return path -def _urljoin(*args): +def urljoin(*args): """ - >>> _urljoin("http://localhost/app1", "edit") + >>> urljoin("http://localhost/app1", "edit") >>> "http://localhost/app1/edit" - >>> _urljoin("app1/", "edit") + >>> urljoin("app1/", "edit") >>> "app1/edit" - >>> _urljoin("app1", "edit") + >>> urljoin("app1", "edit") >>> "app1/edit" - >>> _urljoin("/app1/", "/edit") - >>> "app1/edit" + >>> urljoin("/app1/", "/edit") + >>> "/app1/edit" """ + root_part = args[0] + root_part_is_root_path = root_part.startswith('/') and len(root_part) > 1 + url_strip_parts = [] for part in args: if part: - url_strip_parts.append(_urlstrip(part)) + url_strip_parts.append(urlstrip(part)) - return '/'.join(url_strip_parts) + return '/'.join(url_strip_parts) if root_part_is_root_path is False else '/' + '/'.join(url_strip_parts) -def _urlstrip(url_path: str): +def urlstrip(url_path: str): """ - >>> _urlstrip("/app1/") + >>> urlstrip("/app1/") >>> "app1" - >>> _urlstrip("http://localhost/app1") + >>> urlstrip("http://localhost/app1") >>> "http://localhost/app1" - >>> _urlstrip("http://localhost/app1/") + >>> urlstrip("http://localhost/app1/") >>> "http://localhost/app1" """ return url_path.strip('/') diff --git a/tests/backend/test_auth.py b/tests/backend/test_auth.py index 7236d0cf..443ef11c 100644 --- a/tests/backend/test_auth.py +++ b/tests/backend/test_auth.py @@ -1,6 +1,8 @@ import fastapi import fastapi.testclient +import pytest import writer.serve +from writer import auth from tests.backend import test_basicauth_dir @@ -35,3 +37,38 @@ def test_basicauth_authentication_module_disabled_when_server_setup_hook_is_disa with fastapi.testclient.TestClient(asgi_app) as client: res = client.get("/api/init") assert res.status_code == 405 + + @pytest.mark.parametrize("path,expected_path", [ + ("", "/"), + ("http://localhost", "/"), + ("http://localhost/", "/"), + ("http://localhost/any", "/any"), + ("http://localhost/any/", "/any/"), + ("/any/yolo", "/any/yolo") + ]) + def test_url_path_scenarios(self, path: str, expected_path: str): + assert auth.urlpath(path) == expected_path + + @pytest.mark.parametrize("path,expected_path", [ + ("/", ""), + ("/yolo", "yolo"), + ("/yolo/", "yolo"), + ("http://localhost", "http://localhost"), + ("http://localhost/", "http://localhost"), + ("http://localhost/any", "http://localhost/any"), + ("http://localhost/any/", "http://localhost/any") + ]) + def test_url_split_scenarios(self, path: str, expected_path: str): + assert auth.urlstrip(path) == expected_path + + @pytest.mark.parametrize("path1,path2,expected_path", [ + ("/", "any", "/any"), + ("", "any", "any"), + ("/yolo", "any", "/yolo/any"), + ("/yolo", "/any", "/yolo/any"), + ("http://localhost", "any", "http://localhost/any"), + ("http://localhost/", "/any", "http://localhost/any"), + ("http://localhost/yolo", "/any", "http://localhost/yolo/any"), + ]) + def test_urljoin_scenarios(self, path1: str, path2, expected_path: str): + assert auth.urljoin(path1, path2) == expected_path