Skip to content

Commit

Permalink
Updated refreshing (onyx-dot-app#2327)
Browse files Browse the repository at this point in the history
* clean up + add environment variables

* remove log

* update

* update api settings

* somewhat cleaner refresh functionality

* fully functional

* update settings

* validated

* remove random logs

* remove unneeded paramter + log

* move to ee + remove comments

* Cleanup unused

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
  • Loading branch information
2 people authored and rajiv chodisetti committed Oct 2, 2024
1 parent fc68b88 commit e077edb
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 23 deletions.
2 changes: 1 addition & 1 deletion backend/danswer/server/settings/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def fetch_settings(
return UserSettings(
**general_settings.model_dump(),
notifications=user_notifications,
needs_reindexing=needs_reindexing
needs_reindexing=needs_reindexing,
)


Expand Down
117 changes: 117 additions & 0 deletions backend/ee/danswer/server/enterprise_settings/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,24 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone

import httpx
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi import status
from fastapi import UploadFile
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from ee.danswer.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.danswer.server.enterprise_settings.models import EnterpriseSettings
from ee.danswer.server.enterprise_settings.store import _LOGO_FILENAME
Expand All @@ -18,10 +28,117 @@
from ee.danswer.server.enterprise_settings.store import store_analytics_script
from ee.danswer.server.enterprise_settings.store import store_settings
from ee.danswer.server.enterprise_settings.store import upload_logo
from shared_configs.configs import CUSTOM_REFRESH_URL

admin_router = APIRouter(prefix="/admin/enterprise-settings")
basic_router = APIRouter(prefix="/enterprise-settings")

logger = setup_logger()


def mocked_refresh_token() -> dict:
"""
This function mocks the response from a token refresh endpoint.
It generates a mock access token, refresh token, and user information
with an expiration time set to 1 hour from now.
This is useful for testing or development when the actual refresh endpoint is not available.
"""
mock_exp = int((datetime.now() + timedelta(hours=1)).timestamp() * 1000)
data = {
"access_token": "asdf Mock access token",
"refresh_token": "asdf Mock refresh token",
"session": {"exp": mock_exp},
"userinfo": {
"sub": "Mock email",
"familyName": "Mock name",
"givenName": "Mock name",
"fullName": "Mock name",
"userId": "Mock User ID",
"email": "test_email@danswer.ai",
},
}
return data


@basic_router.get("/refresh-token")
async def refresh_access_token(
user: User = Depends(current_user),
user_manager: UserManager = Depends(get_user_manager),
) -> None:
# return
if CUSTOM_REFRESH_URL is None:
logger.error(
"Custom refresh URL is not set and client is attempting to custom refresh"
)
raise HTTPException(
status_code=500,
detail="Custom refresh URL is not set",
)

try:
async with httpx.AsyncClient() as client:
logger.debug(f"Sending request to custom refresh URL for user {user.id}")
access_token = user.oauth_accounts[0].access_token

response = await client.get(
CUSTOM_REFRESH_URL,
params={"info": "json", "access_token_refresh_interval": 3600},
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()

# NOTE: Here is where we can mock the response
# data = mocked_refresh_token()

logger.debug(f"Received response from Meechum auth URL for user {user.id}")

# Extract new tokens
new_access_token = data["access_token"]
new_refresh_token = data["refresh_token"]

new_expiry = datetime.fromtimestamp(
data["session"]["exp"] / 1000, tz=timezone.utc
)
expires_at_timestamp = int(new_expiry.timestamp())

logger.debug(f"Access token has been refreshed for user {user.id}")

await user_manager.oauth_callback(
oauth_name="custom",
access_token=new_access_token,
account_id=data["userinfo"]["userId"],
account_email=data["userinfo"]["email"],
expires_at=expires_at_timestamp,
refresh_token=new_refresh_token,
associate_by_email=True,
)

logger.info(f"Successfully refreshed tokens for user {user.id}")

except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
logger.warning(f"Full authentication required for user {user.id}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Full authentication required",
)
logger.error(
f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to refresh token",
)
except Exception as e:
logger.error(
f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}"
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="An unexpected error occurred",
)


@admin_router.put("")
def put_settings(
Expand Down
2 changes: 2 additions & 0 deletions backend/shared_configs/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,5 @@
"passage_prefix",
"query_prefix",
]

CUSTOM_REFRESH_URL = os.environ.get("CUSTOM_REFRESH_URL") or "/settings/refresh-token"
1 change: 1 addition & 0 deletions web/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}


RUN npx next build
#RUN npm run build

Expand Down
4 changes: 1 addition & 3 deletions web/src/app/chat/ChatPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ import FunctionalHeader from "@/components/chat_search/Header";
import { useSidebarVisibility } from "@/components/chat_search/hooks";
import { SIDEBAR_TOGGLED_COOKIE_NAME } from "@/components/resizable/constants";
import FixedLogo from "./shared_chat_search/FixedLogo";
import { getSecondsUntilExpiration } from "@/lib/time";
import { SetDefaultModelModal } from "./modal/SetDefaultModelModal";
import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal";
import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown";
Expand Down Expand Up @@ -1559,7 +1558,6 @@ export function ChatPage({
setDocumentSelection((documentSelection) => !documentSelection);
setShowDocSidebar(false);
};
const secondsUntilExpiration = getSecondsUntilExpiration(user);

interface RegenerationRequest {
messageId: number;
Expand All @@ -1579,7 +1577,7 @@ export function ChatPage({

return (
<>
<HealthCheckBanner secondsUntilExpiration={secondsUntilExpiration} />
<HealthCheckBanner />
{/* ChatPopup is a custom popup that displays a admin-specified message on initial user visit.
Only used in the EE version of the app. */}
{popup}
Expand Down
3 changes: 1 addition & 2 deletions web/src/app/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ import "./globals.css";
import {
fetchEnterpriseSettingsSS,
fetchSettingsSS,
SettingsError,
} from "@/components/settings/lib";
import {
CUSTOM_ANALYTICS_ENABLED,
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { SettingsProvider } from "@/components/settings/SettingsProvider";
import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { buildClientUrl, fetchSS } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";
Expand Down
4 changes: 1 addition & 3 deletions web/src/app/search/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import {
getAuthTypeMetadataSS,
getCurrentUserSS,
} from "@/lib/userSS";
import { getSecondsUntilExpiration } from "@/lib/time";
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
Expand Down Expand Up @@ -179,11 +178,10 @@ export default async function Home() {
const agenticSearchEnabled = agenticSearchToggle
? agenticSearchToggle.value.toLocaleLowerCase() == "true" || false
: false;
const secondsUntilExpiration = getSecondsUntilExpiration(user);

return (
<>
<HealthCheckBanner secondsUntilExpiration={secondsUntilExpiration} />
<HealthCheckBanner />
{shouldShowWelcomeModal && <WelcomeModal user={user} />}
<InstantSSRAutoRefresh />

Expand Down
94 changes: 80 additions & 14 deletions web/src/components/health/healthcheck.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,95 @@
import { errorHandlingFetcher, RedirectError } from "@/lib/fetcher";
import useSWR from "swr";
import { Modal } from "../Modal";
import { useState } from "react";
import { useEffect, useState } from "react";
import { getSecondsUntilExpiration } from "@/lib/time";
import { User } from "@/lib/types";

export const HealthCheckBanner = ({
secondsUntilExpiration,
}: {
secondsUntilExpiration?: number | null;
}) => {
export const HealthCheckBanner = () => {
const { error } = useSWR("/api/health", errorHandlingFetcher);
const [expired, setExpired] = useState(false);
const [secondsUntilExpiration, setSecondsUntilExpiration] = useState<
number | null
>(null);
const { data: user, mutate: mutateUser } = useSWR<User>(
"/api/me",
errorHandlingFetcher
);

if (secondsUntilExpiration !== null && secondsUntilExpiration !== undefined) {
setTimeout(
() => {
setExpired(true);
},
secondsUntilExpiration * 1000 - 200
);
}
const updateExpirationTime = async () => {
const updatedUser = await mutateUser();

if (updatedUser) {
const seconds = getSecondsUntilExpiration(updatedUser);
setSecondsUntilExpiration(seconds);
console.debug(`Updated seconds until expiration:! ${seconds}`);
}
};

useEffect(() => {
updateExpirationTime();
}, [user]);

useEffect(() => {
if (true) {
let refreshTimeoutId: NodeJS.Timeout;
let expireTimeoutId: NodeJS.Timeout;

const refreshToken = async () => {
try {
const response = await fetch(
"/api/enterprise-settings/refresh-token",
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
}
);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}

console.debug("Token refresh successful");
// Force revalidation of user data

await mutateUser(undefined, { revalidate: true });
updateExpirationTime();
} catch (error) {
console.error("Error refreshing token:", error);
}
};

const scheduleRefreshAndExpire = () => {
if (secondsUntilExpiration !== null) {
const timeUntilRefresh = (secondsUntilExpiration + 0.5) * 1000;
refreshTimeoutId = setTimeout(refreshToken, timeUntilRefresh);

const timeUntilExpire = (secondsUntilExpiration + 10) * 1000;
expireTimeoutId = setTimeout(() => {
console.debug("Session expired. Setting expired state to true.");
setExpired(true);
}, timeUntilExpire);
}
};

scheduleRefreshAndExpire();

return () => {
clearTimeout(refreshTimeoutId);
clearTimeout(expireTimeoutId);
};
}
}, [secondsUntilExpiration, user]);

if (!error && !expired) {
return null;
}

console.debug(
`Rendering HealthCheckBanner. Error: ${error}, Expired: ${expired}`
);

if (error instanceof RedirectError || expired) {
return (
<Modal
Expand Down
1 change: 1 addition & 0 deletions web/src/lib/time.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ export function getSecondsUntilExpiration(
if (!userInfo) {
return null;
}

const { oidc_expiry, current_token_created_at, current_token_expiry_length } =
userInfo;

Expand Down
1 change: 1 addition & 0 deletions web/src/middleware.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ const eePaths = [
"/admin/whitelabeling",
"/admin/performance/custom-analytics",
];

const eePathsForMatcher = eePaths.map((path) => `${path}/:path*`);

export async function middleware(request: NextRequest) {
Expand Down

0 comments on commit e077edb

Please sign in to comment.