From e077edbb96523bb1f6ffd9859f7f9abc22cb6395 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 5 Sep 2024 21:36:55 -0700 Subject: [PATCH] Updated refreshing (#2327) * clean up + add environment variables * remove log * update * update api settings * somewhat cleaner refresh functionality * fully functional * update settings * validated * remove random logs * remove unneeded paramter + log * move to ee + remove comments * Cleanup unused --------- Co-authored-by: Weves --- backend/danswer/server/settings/api.py | 2 +- .../danswer/server/enterprise_settings/api.py | 117 ++++++++++++++++++ backend/shared_configs/configs.py | 2 + web/Dockerfile | 1 + web/src/app/chat/ChatPage.tsx | 4 +- web/src/app/layout.tsx | 3 +- web/src/app/search/page.tsx | 4 +- web/src/components/health/healthcheck.tsx | 94 +++++++++++--- web/src/lib/time.ts | 1 + web/src/middleware.ts | 1 + 10 files changed, 206 insertions(+), 23 deletions(-) diff --git a/backend/danswer/server/settings/api.py b/backend/danswer/server/settings/api.py index 3330f6cc5ff..5b8564c3d3a 100644 --- a/backend/danswer/server/settings/api.py +++ b/backend/danswer/server/settings/api.py @@ -66,7 +66,7 @@ def fetch_settings( return UserSettings( **general_settings.model_dump(), notifications=user_notifications, - needs_reindexing=needs_reindexing + needs_reindexing=needs_reindexing, ) diff --git a/backend/ee/danswer/server/enterprise_settings/api.py b/backend/ee/danswer/server/enterprise_settings/api.py index 736296517db..8590fd6c5e7 100644 --- a/backend/ee/danswer/server/enterprise_settings/api.py +++ b/backend/ee/danswer/server/enterprise_settings/api.py @@ -1,14 +1,24 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response +from fastapi import status from fastapi import UploadFile from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.auth.users import current_user +from danswer.auth.users import get_user_manager +from danswer.auth.users import UserManager from danswer.db.engine import get_session from danswer.db.models import User from danswer.file_store.file_store import get_default_file_store +from danswer.utils.logger import setup_logger from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload from ee.danswer.server.enterprise_settings.models import EnterpriseSettings from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME @@ -18,10 +28,117 @@ from ee.danswer.server.enterprise_settings.store import store_analytics_script from ee.danswer.server.enterprise_settings.store import store_settings from ee.danswer.server.enterprise_settings.store import upload_logo +from shared_configs.configs import CUSTOM_REFRESH_URL admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") +logger = setup_logger() + + +def mocked_refresh_token() -> dict: + """ + This function mocks the response from a token refresh endpoint. + It generates a mock access token, refresh token, and user information + with an expiration time set to 1 hour from now. + This is useful for testing or development when the actual refresh endpoint is not available. + """ + mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000) + data = { + "access_token": "asdf Mock access token", + "refresh_token": "asdf Mock refresh token", + "session": {"exp": mock_exp}, + "userinfo": { + "sub": "Mock email", + "familyName": "Mock name", + "givenName": "Mock name", + "fullName": "Mock name", + "userId": "Mock User ID", + "email": "test_email@danswer.ai", + }, + } + return data + + +@basic_router.get("/refresh-token") +async def refresh_access_token( + user: User = Depends(current_user), + user_manager: UserManager = Depends(get_user_manager), +) -> None: + # return + if CUSTOM_REFRESH_URL is None: + logger.error( + "Custom refresh URL is not set and client is attempting to custom refresh" + ) + raise HTTPException( + status_code=500, + detail="Custom refresh URL is not set", + ) + + try: + async with httpx.AsyncClient() as client: + logger.debug(f"Sending request to custom refresh URL for user {user.id}") + access_token = user.oauth_accounts[0].access_token + + response = await client.get( + CUSTOM_REFRESH_URL, + params={"info": "json", "access_token_refresh_interval": 3600}, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + data = response.json() + + # NOTE: Here is where we can mock the response + # data = mocked_refresh_token() + + logger.debug(f"Received response from Meechum auth URL for user {user.id}") + + # Extract new tokens + new_access_token = data["access_token"] + new_refresh_token = data["refresh_token"] + + new_expiry = datetime.fromtimestamp( + data["session"]["exp"] / 1000, tz=timezone.utc + ) + expires_at_timestamp = int(new_expiry.timestamp()) + + logger.debug(f"Access token has been refreshed for user {user.id}") + + await user_manager.oauth_callback( + oauth_name="custom", + access_token=new_access_token, + account_id=data["userinfo"]["userId"], + account_email=data["userinfo"]["email"], + expires_at=expires_at_timestamp, + refresh_token=new_refresh_token, + associate_by_email=True, + ) + + logger.info(f"Successfully refreshed tokens for user {user.id}") + + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + logger.warning(f"Full authentication required for user {user.id}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Full authentication required", + ) + logger.error( + f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to refresh token", + ) + except Exception as e: + logger.error( + f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}" + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred", + ) + @admin_router.put("") def put_settings( diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index f5698d6a6ff..3657e762667 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -70,3 +70,5 @@ "passage_prefix", "query_prefix", ] + +CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token" diff --git a/web/Dockerfile b/web/Dockerfile index 0e0ab74ccf3..88aa47a4915 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -58,6 +58,7 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T ARG NEXT_PUBLIC_DISABLE_LOGOUT ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT} + RUN npx next build #RUN npm run build diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index a969fed0e70..6499cde5fb4 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -91,7 +91,6 @@ import FunctionalHeader from "@/components/chat_search/Header"; import { useSidebarVisibility } from "@/components/chat_search/hooks"; import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants"; import FixedLogo from "./shared_chat_search/FixedLogo"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { SetDefaultModelModal } from "./modal/SetDefaultModelModal"; import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal"; import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown"; @@ -1559,7 +1558,6 @@ export function ChatPage({ setDocumentSelection((documentSelection) => !documentSelection); setShowDocSidebar(false); }; - const secondsUntilExpiration = getSecondsUntilExpiration(user); interface RegenerationRequest { messageId: number; @@ -1579,7 +1577,7 @@ export function ChatPage({ return ( <> - + {/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit. Only used in the EE version of the app. */} {popup} diff --git a/web/src/app/layout.tsx b/web/src/app/layout.tsx index 48127065a2e..a438b6193ea 100644 --- a/web/src/app/layout.tsx +++ b/web/src/app/layout.tsx @@ -3,7 +3,6 @@ import "./globals.css"; import { fetchEnterpriseSettingsSS, fetchSettingsSS, - SettingsError, } from "@/components/settings/lib"; import { CUSTOM_ANALYTICS_ENABLED, @@ -11,7 +10,7 @@ import { } from "@/lib/constants"; import { SettingsProvider } from "@/components/settings/SettingsProvider"; import { Metadata } from "next"; -import { buildClientUrl } from "@/lib/utilsSS"; +import { buildClientUrl, fetchSS } from "@/lib/utilsSS"; import { Inter } from "next/font/google"; import Head from "next/head"; import { EnterpriseSettings } from "./admin/settings/interfaces"; diff --git a/web/src/app/search/page.tsx b/web/src/app/search/page.tsx index 6f6cef8c4f0..40b4c5e53df 100644 --- a/web/src/app/search/page.tsx +++ b/web/src/app/search/page.tsx @@ -3,7 +3,6 @@ import { getAuthTypeMetadataSS, getCurrentUserSS, } from "@/lib/userSS"; -import { getSecondsUntilExpiration } from "@/lib/time"; import { redirect } from "next/navigation"; import { HealthCheckBanner } from "@/components/health/healthcheck"; import { ApiKeyModal } from "@/components/llm/ApiKeyModal"; @@ -179,11 +178,10 @@ export default async function Home() { const agenticSearchEnabled = agenticSearchToggle ? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false : false; - const secondsUntilExpiration = getSecondsUntilExpiration(user); return ( <> - + {shouldShowWelcomeModal && } diff --git a/web/src/components/health/healthcheck.tsx b/web/src/components/health/healthcheck.tsx index a8110ba8c55..2cba8be8278 100644 --- a/web/src/components/health/healthcheck.tsx +++ b/web/src/components/health/healthcheck.tsx @@ -3,29 +3,95 @@ import { errorHandlingFetcher, RedirectError } from "@/lib/fetcher"; import useSWR from "swr"; import { Modal } from "../Modal"; -import { useState } from "react"; +import { useEffect, useState } from "react"; +import { getSecondsUntilExpiration } from "@/lib/time"; +import { User } from "@/lib/types"; -export const HealthCheckBanner = ({ - secondsUntilExpiration, -}: { - secondsUntilExpiration?: number | null; -}) => { +export const HealthCheckBanner = () => { const { error } = useSWR("/api/health", errorHandlingFetcher); const [expired, setExpired] = useState(false); + const [secondsUntilExpiration, setSecondsUntilExpiration] = useState< + number | null + >(null); + const { data: user, mutate: mutateUser } = useSWR( + "/api/me", + errorHandlingFetcher + ); - if (secondsUntilExpiration !== null && secondsUntilExpiration !== undefined) { - setTimeout( - () => { - setExpired(true); - }, - secondsUntilExpiration * 1000 - 200 - ); - } + const updateExpirationTime = async () => { + const updatedUser = await mutateUser(); + + if (updatedUser) { + const seconds = getSecondsUntilExpiration(updatedUser); + setSecondsUntilExpiration(seconds); + console.debug(`Updated seconds until expiration:! ${seconds}`); + } + }; + + useEffect(() => { + updateExpirationTime(); + }, [user]); + + useEffect(() => { + if (true) { + let refreshTimeoutId: NodeJS.Timeout; + let expireTimeoutId: NodeJS.Timeout; + + const refreshToken = async () => { + try { + const response = await fetch( + "/api/enterprise-settings/refresh-token", + { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + } + ); + if (!response.ok) { + throw new Error(`HTTP error! status: ${response.status}`); + } + + console.debug("Token refresh successful"); + // Force revalidation of user data + + await mutateUser(undefined, { revalidate: true }); + updateExpirationTime(); + } catch (error) { + console.error("Error refreshing token:", error); + } + }; + + const scheduleRefreshAndExpire = () => { + if (secondsUntilExpiration !== null) { + const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000; + refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh); + + const timeUntilExpire = (secondsUntilExpiration + 10) * 1000; + expireTimeoutId = setTimeout(() => { + console.debug("Session expired. Setting expired state to true."); + setExpired(true); + }, timeUntilExpire); + } + }; + + scheduleRefreshAndExpire(); + + return () => { + clearTimeout(refreshTimeoutId); + clearTimeout(expireTimeoutId); + }; + } + }, [secondsUntilExpiration, user]); if (!error && !expired) { return null; } + console.debug( + `Rendering HealthCheckBanner. Error: ${error}, Expired: ${expired}` + ); + if (error instanceof RedirectError || expired) { return ( `${path}/:path*`); export async function middleware(request: NextRequest) {