Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/check rerank model selected #1577

Merged
merged 9 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion web/app/components/app/chat/citation/popup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ const Popup: FC<PopupProps> = ({
data={source.index_node_hash.substring(0, 7)}
icon={<BezierCurve03 className='mr-1 w-3 h-3' />}
/>
<ProgressTooltip data={Number(source.score.toFixed(2))} />
{
source.score && (
<ProgressTooltip data={Number(source.score.toFixed(2))} />
)
}
</div>
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleValueChange = (type: string, value: string) => {
Expand All @@ -78,6 +79,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
Expand Down Expand Up @@ -270,7 +272,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
)}

<div
className='absolute z-10 bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white '
className='absolute z-[5] bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white '
style={{
borderColor: 'rgba(0, 0, 0, 0.05)',
}}
Expand Down
19 changes: 15 additions & 4 deletions web/app/components/datasets/common/check-rerank-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelVaild,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: BackendModel
isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
indexMethod?: string
}) => {
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined)
const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)

if (isRerankDefaultModelVaild)
return !!rerankDefaultModel

return false
})()

if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& !rerankModel
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModelSelected
)
return false

Expand All @@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel
) {
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@ type Props = {
}

const RetrievalMethodConfig: FC<Props> = ({
value,
value: passValue,
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext()
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
},
}
}
return passValue
})()
return (
<div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
Expand Down
5 changes: 5 additions & 0 deletions web/app/components/datasets/create/step-two/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ const StepTwo = ({
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const getCreationParams = () => {
let params
Expand All @@ -282,6 +283,7 @@ const StepTwo = ({
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
indexMethod: indexMethod as string,
Expand Down Expand Up @@ -359,6 +361,9 @@ const StepTwo = ({
try {
let res
const params = getCreationParams()
if (!params)
return false

setIsCreating(true)
if (!datasetId) {
res = await createFirstDocument({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ import type { FC } from 'react'
import React, { useRef, useState } from 'react'
import { useClickAway } from 'ahooks'
import { useTranslation } from 'react-i18next'
import Toast from '../../base/toast'
import { XClose } from '@/app/components/base/icons/src/vender/line/general'
import type { RetrievalConfig } from '@/types/app'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import Button from '@/app/components/base/button'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'

type Props = {
indexMethod: string
Expand All @@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({
onHide()
}, ref)

const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleSave = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
) {
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
onSave(ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig,
indexMethod,
}))
}

if (!isShow)
return null

Expand Down Expand Up @@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({
}}
>
<Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={() => onSave(retrievalConfig)} >{t('common.operation.save')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
</div>
</div>
)
Expand Down
2 changes: 2 additions & 0 deletions web/app/components/datasets/settings/form/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const Form = () => {
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleSave = async () => {
Expand All @@ -72,6 +73,7 @@ const Form = () => {
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ const config: ProviderConfig = {
'en': <CohereText className='w-[120px] h-6' />,
'zh-Hans': <CohereText className='w-[120px] h-6' />,
},
hit: {
'en': 'Rerank Model Supported',
'zh-Hans': '支持 Rerank 模型',
},
},
modal: {
key: ProviderEnum.cohere,
title: {
'en': 'cohere',
'zh-Hans': 'cohere',
'en': 'Rerank Model',
'zh-Hans': 'Rerank 模型',
},
icon: <Cohere className='w-6 h-6' />,
link: {
Expand Down
22 changes: 19 additions & 3 deletions web/app/components/header/account-setting/model-page/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import I18n from '@/context/i18n'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'

const MODEL_CARD_LIST = [
config.openai,
Expand All @@ -42,6 +43,10 @@ const ModelPage = () => {
const { locale } = useContext(I18n)
const {
updateModelList,
textGenerationDefaultModel,
embeddingsDefaultModel,
speech2textDefaultModel,
rerankDefaultModel,
} = useProviderContext()
const { data: providers, mutate: mutateProviders } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
const [showModal, setShowModal] = useState(false)
Expand Down Expand Up @@ -196,11 +201,22 @@ const ModelPage = () => {
}
}

const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel

return (
<div className='relative pt-1 -mt-2'>
<div className='flex items-center justify-between mb-2 h-8'>
<div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div>
<SystemModel />
<div className={`flex items-center justify-between mb-2 h-8 ${defaultModelNotConfigured && 'px-3 bg-[#FFFAEB] rounded-lg border border-[#FEF0C7]'}`}>
{
defaultModelNotConfigured
? (
<div className='flex items-center text-xs font-medium text-gray-700'>
<AlertTriangle className='mr-1 w-3 h-3 text-[#F79009]' />
{t('common.modelProvider.notConfigured')}
</div>
)
: <div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div>
}
<SystemModel onUpdate={() => mutateProviders()} />
</div>
<div className='grid grid-cols-2 gap-4 mb-6'>
{
Expand Down
Loading
Loading