Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/use jwt in web #533

Merged
merged 12 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ venv/
ENV/
env.bak/
venv.bak/
.conda/

# Spyder project settings
.spyderproject
Expand Down
2 changes: 1 addition & 1 deletion api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def register_blueprints(app):
resources={
r"/*": {"origins": app.config['WEB_API_CORS_ALLOW_ORIGINS']}},
supports_credentials=True,
allow_headers=['Content-Type', 'Authorization'],
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'],
expose_headers=['X-Version', 'X-Env']
)
Expand Down
2 changes: 1 addition & 1 deletion api/controllers/web/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
api = ExternalApi(bp)


from . import completion, app, conversation, message, site, saved_message, audio
from . import completion, app, conversation, message, site, saved_message, audio, passport
64 changes: 64 additions & 0 deletions api/controllers/web/passport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding:utf-8 -*-
import uuid
from controllers.web import api
from flask_restful import Resource
from flask import request
from werkzeug.exceptions import Unauthorized, NotFound
from models.model import Site, EndUser, App
from extensions.ext_database import db
from libs.passport import PassportService

class PassportResource(Resource):
"""Base resource for passport."""
def get(self):
app_id = request.headers.get('X-App-Code')
if app_id is None:
raise Unauthorized('X-App-Code header is missing.')

# get site from db and check if it is normal
site = db.session.query(Site).filter(
Site.code == app_id,
Site.status == 'normal'
).first()
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model or app_model.status != 'normal' or not app_model.enable_site:
raise NotFound()

end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=generate_session_id(),
)
db.session.add(end_user)
db.session.commit()

payload = {
"iss": site.app_id,
'sub': 'Web API Passport',
'app_id': site.app_id,
'end_user_id': end_user.id,
}

tk = PassportService().issue(payload)

return {
'access_token': tk,
}

api.add_resource(PassportResource, '/passport')

def generate_session_id():
"""
Generate a unique session ID.
"""
while True:
session_id = str(uuid.uuid4())
existing_count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()
if existing_count == 0:
return session_id
94 changes: 16 additions & 78 deletions api/controllers/web/wraps.py
Original file line number Diff line number Diff line change
@@ -1,110 +1,48 @@
# -*- coding:utf-8 -*-
import uuid
from functools import wraps

from flask import request, session
from flask import request
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized

from extensions.ext_database import db
from models.model import App, Site, EndUser
from models.model import App, EndUser
from libs.passport import PassportService


def validate_token(view=None):
def validate_jwt_token(view=None):
def decorator(view):
@wraps(view)
def decorated(*args, **kwargs):
site = validate_and_get_site()

app_model = db.session.query(App).filter(App.id == site.app_id).first()
if not app_model:
raise NotFound()

if app_model.status != 'normal':
raise NotFound()

if not app_model.enable_site:
raise NotFound()

end_user = create_or_update_end_user_for_session(app_model)
app_model, end_user = decode_jwt_token()

return view(app_model, end_user, *args, **kwargs)
return decorated

if view:
return decorator(view)
return decorator


def validate_and_get_site():
"""
Validate and get API token.
"""
def decode_jwt_token():
auth_header = request.headers.get('Authorization')
if auth_header is None:
raise Unauthorized('Authorization header is missing.')

if ' ' not in auth_header:
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()

if auth_scheme != 'bearer':
raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')

site = db.session.query(Site).filter(
Site.code == auth_token,
Site.status == 'normal'
).first()

if not site:
decoded = PassportService().verify(tk)
app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
if not app_model:
raise NotFound()
end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
if not end_user:
raise NotFound()

return site


def create_or_update_end_user_for_session(app_model):
"""
Create or update session terminal based on session ID.
"""
if 'session_id' not in session:
session['session_id'] = generate_session_id()

session_id = session.get('session_id')
end_user = db.session.query(EndUser) \
.filter(
EndUser.session_id == session_id,
EndUser.type == 'browser'
).first()

if end_user is None:
end_user = EndUser(
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type='browser',
is_anonymous=True,
session_id=session_id
)
db.session.add(end_user)
db.session.commit()

return end_user


def generate_session_id():
"""
Generate a unique session ID.
"""
count = 1
session_id = ''
while count != 0:
session_id = str(uuid.uuid4())
count = db.session.query(EndUser) \
.filter(EndUser.session_id == session_id).count()

return session_id

return app_model, end_user

class WebApiResource(Resource):
method_decorators = [validate_token]
method_decorators = [validate_jwt_token]
20 changes: 20 additions & 0 deletions api/libs/passport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# -*- coding:utf-8 -*-
import jwt
from werkzeug.exceptions import Unauthorized
from flask import current_app
class PassportService:
def __init__(self):
self.sk = current_app.config.get('SECRET_KEY')

def issue(self, payload):
return jwt.encode(payload, self.sk, algorithm='HS256')

def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=['HS256'])
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized('Invalid token signature.')
except jwt.exceptions.DecodeError:
raise Unauthorized('Invalid token.')
except jwt.exceptions.ExpiredSignatureError:
raise Unauthorized('Token has expired.')
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,5 @@ redis~=4.5.4
openpyxl==3.1.2
chardet~=5.1.0
docx2txt==0.8
pypdfium2==4.16.0
pypdfium2==4.16.0
pyjwt~=2.6.0
19 changes: 17 additions & 2 deletions web/app/components/share/chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,26 @@ import { useContext } from 'use-context-selector'
import produce from 'immer'
import { useBoolean, useGetState } from 'ahooks'
import AppUnavailable from '../../base/app-unavailable'
import { checkOrSetAccessToken } from '../utils'
import useConversation from './hooks/use-conversation'
import s from './style.module.css'
import { ToastContext } from '@/app/components/base/toast'
import Sidebar from '@/app/components/share/chat/sidebar'
import ConfigSence from '@/app/components/share/chat/config-scence'
import Header from '@/app/components/share/header'
import { delConversation, fetchAppInfo, fetchAppParams, fetchChatList, fetchConversations, fetchSuggestedQuestions, pinConversation, sendChatMessage, stopChatMessageResponding, unpinConversation, updateFeedback } from '@/service/share'
import {
delConversation,
fetchAppInfo,
fetchAppParams,
fetchChatList,
fetchConversations,
fetchSuggestedQuestions,
pinConversation,
sendChatMessage,
stopChatMessageResponding,
unpinConversation,
updateFeedback,
} from '@/service/share'
import type { ConversationItem, SiteInfo } from '@/models/share'
import type { PromptConfig, SuggestedQuestionsAfterAnswerConfig } from '@/models/debug'
import type { Feedbacktype, IChatItem } from '@/app/components/app/chat'
Expand Down Expand Up @@ -296,7 +309,9 @@ const Main: FC<IMainProps> = ({
return fetchConversations(isInstalledApp, installedAppInfo?.id, undefined, undefined, 100)
}

const fetchInitData = () => {
const fetchInitData = async () => {
await checkOrSetAccessToken()

return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,
Expand Down
10 changes: 5 additions & 5 deletions web/app/components/share/text-generation/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { useBoolean, useClickAway, useGetState } from 'ahooks'
import { XMarkIcon } from '@heroicons/react/24/outline'
import TabHeader from '../../base/tab-header'
import Button from '../../base/button'
import { checkOrSetAccessToken } from '../utils'
import s from './style.module.css'
import RunBatch from './run-batch'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
Expand Down Expand Up @@ -76,9 +77,6 @@ const TextGeneration: FC<IMainProps> = ({
const res: any = await doFetchSavedMessage(isInstalledApp, installedAppInfo?.id)
setSavedMessages(res.data)
}
useEffect(() => {
fetchSavedMessage()
}, [])
const handleSaveMessage = async (messageId: string) => {
await saveMessage(messageId, isInstalledApp, installedAppInfo?.id)
notify({ type: 'success', message: t('common.api.saved') })
Expand Down Expand Up @@ -256,7 +254,9 @@ const TextGeneration: FC<IMainProps> = ({
setAllTaskList(newAllTaskList)
}

const fetchInitData = () => {
const fetchInitData = async () => {
await checkOrSetAccessToken()

return Promise.all([isInstalledApp
? {
app_id: installedAppInfo?.id,
Expand All @@ -267,7 +267,7 @@ const TextGeneration: FC<IMainProps> = ({
},
plan: 'basic',
}
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id)])
: fetchAppInfo(), fetchAppParams(isInstalledApp, installedAppInfo?.id), fetchSavedMessage()])
}

useEffect(() => {
Expand Down
18 changes: 18 additions & 0 deletions web/app/components/share/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { fetchAccessToken } from '@/service/share'

export const checkOrSetAccessToken = async () => {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch (e) {

}
if (!accessTokenJson[sharedToken]) {
const res = await fetchAccessToken(sharedToken)
accessTokenJson[sharedToken] = res.access_token
localStorage.setItem('token', JSON.stringify(accessTokenJson))
}
}
12 changes: 10 additions & 2 deletions web/service/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,15 @@ const baseFetch = (
const options = Object.assign({}, baseOptions, fetchOptions)
if (isPublicAPI) {
const sharedToken = globalThis.location.pathname.split('/').slice(-1)[0]
options.headers.set('Authorization', `bearer ${sharedToken}`)
const accessToken = localStorage.getItem('token') || JSON.stringify({ [sharedToken]: '' })
let accessTokenJson = { [sharedToken]: '' }
try {
accessTokenJson = JSON.parse(accessToken)
}
catch (e) {

}
options.headers.set('Authorization', `Bearer ${accessTokenJson[sharedToken]}`)
}

if (deleteContentType) {
Expand Down Expand Up @@ -194,7 +202,7 @@ const baseFetch = (
case 401: {
if (isPublicAPI) {
Toast.notify({ type: 'error', message: 'Invalid token' })
return
return bodyJson.then((data: any) => Promise.reject(data))
}
const loginUrl = `${globalThis.location.origin}/signin`
if (IS_CE_EDITION) {
Expand Down
6 changes: 6 additions & 0 deletions web/service/share.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ export const fetchSuggestedQuestions = (messageId: string, isInstalledApp: boole
export const audioToText = (url: string, isPublicAPI: boolean, body: FormData) => {
return (getAction('post', !isPublicAPI))(url, { body }, { bodyStringify: false, deleteContentType: true }) as Promise<{ text: string }>
}

export const fetchAccessToken = async (appCode: string) => {
const headers = new Headers()
headers.append('X-App-Code', appCode)
return get('/passport', { headers }) as Promise<{ access_token: string }>
}