diff --git a/backend/app/api/routes/login.py b/backend/app/api/routes/login.py index 1a196589c4..a57053b363 100644 --- a/backend/app/api/routes/login.py +++ b/backend/app/api/routes/login.py @@ -1,19 +1,22 @@ from datetime import timedelta from typing import Annotated, Any -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, Form, HTTPException from fastapi.responses import HTMLResponse from fastapi.security import OAuth2PasswordRequestForm +from pydantic import BaseModel from app import crud -from app.api.deps import CurrentUser, SessionDep, get_first_superuser +from app.api.deps import CurrentUser, RedisDep, SessionDep, get_first_superuser from app.core import security from app.core.config import settings from app.core.security import get_password_hash from app.models import Message, NewPassword, Token, UserMePublic, UserPublic from app.utils import ( + create_and_store_device_code, generate_password_reset_token, generate_reset_password_email, + generate_user_code, send_email, verify_password_reset_token, ) @@ -60,6 +63,40 @@ def login_access_token( ) +class DeviceAuthorizationResponse(BaseModel): + device_code: str + user_code: str + verification_uri: str + verification_uri_complete: str + expires_in: int + interval: int + + +@router.post("/login/device/authorization") +async def device_authorization( + client_id: Annotated[str, Form()], + redis: RedisDep, +) -> DeviceAuthorizationResponse: + """ + Device Authorization Grant + """ + user_code = generate_user_code() + + device_code = create_and_store_device_code(user_code, client_id, redis) + + verification_uri = f"{settings.server_host}/device" + verification_uri_complete = f"{verification_uri}?code={user_code}" + + return DeviceAuthorizationResponse( + device_code=str(device_code), + user_code=str(user_code), + verification_uri=verification_uri, + verification_uri_complete=verification_uri_complete, + expires_in=settings.DEVICE_AUTH_TTL_MINUTES * 60, + interval=settings.DEVICE_AUTH_POLL_INTERVAL_SECONDS, + ) + + @router.post("/login/test-token", response_model=UserPublic) def test_token(current_user: CurrentUser) -> Any: """ diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 3c3dcc8668..13a5a4802f 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -29,8 +29,13 @@ class Settings(BaseSettings): ) API_V1_STR: str = "/api/v1" SECRET_KEY: str = secrets.token_urlsafe(32) + + # AUTH # 60 minutes * 24 hours * 8 days = 8 days ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8 + DEVICE_AUTH_TTL_MINUTES: int = 5 + DEVICE_AUTH_POLL_INTERVAL_SECONDS: int = 5 + DOMAIN: str = "localhost" ENVIRONMENT: Literal["local", "staging", "production"] = "local" diff --git a/backend/app/tests/api/routes/test_device_auth.py b/backend/app/tests/api/routes/test_device_auth.py new file mode 100644 index 0000000000..8df080f1bb --- /dev/null +++ b/backend/app/tests/api/routes/test_device_auth.py @@ -0,0 +1,24 @@ +from fastapi.testclient import TestClient + +from app.core.config import settings + + +def test_get_device_code(client: TestClient) -> None: + data = {"client_id": "valid_id"} + + r = client.post(f"{settings.API_V1_STR}/login/device/authorization", data=data) + + assert r.status_code == 200 + + response_data = r.json() + + assert "device_code" in response_data + assert "user_code" in response_data + assert "expires_in" in response_data + assert "interval" in response_data + + assert response_data["verification_uri"] == f"{settings.server_host}/device" + assert ( + response_data["verification_uri_complete"] + == f"{settings.server_host}/device?code={response_data['user_code']}" + ) diff --git a/backend/app/utils.py b/backend/app/utils.py index 2a73835b85..17174ea19f 100644 --- a/backend/app/utils.py +++ b/backend/app/utils.py @@ -1,6 +1,7 @@ import logging import re import unicodedata +import uuid from dataclasses import dataclass from datetime import datetime, timedelta, timezone from enum import Enum @@ -11,6 +12,8 @@ import jwt from jinja2 import Template from jwt.exceptions import InvalidTokenError +from pydantic import BaseModel +from redis import Redis from app.core.config import settings @@ -253,3 +256,67 @@ def generate_account_deletion_email(email_to: str) -> EmailData: }, ) return EmailData(html_content=html_content, subject=subject) + + +def generate_user_code() -> str: + """Generates a unique user code for device auth.""" + + # RFC 8628 suggest to return an easy to type code, but since + # we'll automatically open the browser when authenticating + # from the CLI, it should be fine to return a uuid, this + # means we don't have to worry about potential user code + # collisions. + return str(uuid.uuid4()) + + +class DeviceAuthorizationData(BaseModel): + device_code: str + client_id: str + expires_at: datetime + + +def create_and_store_device_code( + user_code: str, client_id: str, redis: "Redis[Any]" +) -> str: + """Create a new device code and store it in Redis. + + The device code is generated and stored in Redis with the following structure: + - key: auth:device: + - value: { + "device_code": , + "client_id": , + "expires_at": + } + + Additionally, a mapping from the user code to the device code is stored in Redis with the following structure: + - key: auth:user-code: + - value: + + The device code is returned if it was successfully stored in Redis. + """ + now = get_datetime_utc() + + device_code = str(uuid.uuid4()) + + data = DeviceAuthorizationData( + device_code=device_code, + client_id=client_id, + expires_at=now + timedelta(minutes=settings.DEVICE_AUTH_TTL_MINUTES), + ) + + pipeline = redis.pipeline(True) + + pipeline.set( + f"auth:device:{device_code}", + data.model_dump_json(), + ex=settings.DEVICE_AUTH_TTL_MINUTES * 60, + ) + pipeline.set( + f"auth:user-code:{user_code}", + device_code, + ex=settings.DEVICE_AUTH_TTL_MINUTES * 60, + ) + + pipeline.execute() + + return device_code