diff --git a/src/app/chat/features/Conversation/ChatList/Messages/Assistant.tsx b/src/app/chat/features/Conversation/ChatList/Messages/Assistant.tsx index 943e2f91759eb..a453f24cb0aa9 100644 --- a/src/app/chat/features/Conversation/ChatList/Messages/Assistant.tsx +++ b/src/app/chat/features/Conversation/ChatList/Messages/Assistant.tsx @@ -3,7 +3,7 @@ import { memo } from 'react'; import { useSessionStore } from '@/store/session'; import { chatSelectors } from '@/store/session/selectors'; -import { isFunctionMessage } from '@/utils/message'; +import { isFunctionMessageAtStart } from '@/utils/message'; import FunctionCall from '../Plugins/FunctionCall'; import { DefaultMessage } from './Default'; @@ -12,7 +12,8 @@ export const AssistantMessage: RenderMessage = memo( ({ id, plugin, function_call, content, ...props }) => { const genFunctionCallProps = useSessionStore(chatSelectors.getFunctionMessageParams); - if (!isFunctionMessage(content)) return ; + if (!isFunctionMessageAtStart(content)) + return ; const fcProps = genFunctionCallProps({ content, function_call, id, plugin }); diff --git a/src/const/plugin.ts b/src/const/plugin.ts index 8627ecba9a8f3..5194a6f12873a 100644 --- a/src/const/plugin.ts +++ b/src/const/plugin.ts @@ -1 +1 @@ -export const PLUGIN_SCHEMA_SEPARATOR = '--__--'; +export const PLUGIN_SCHEMA_SEPARATOR = '____'; diff --git a/src/store/plugin/selectors.ts b/src/store/plugin/selectors.ts index 663fafa701ac6..0a4398d6398b5 100644 --- a/src/store/plugin/selectors.ts +++ b/src/store/plugin/selectors.ts @@ -18,11 +18,17 @@ const enabledSchema = return enabledPlugins.includes(p.identifier); }) .flatMap((manifest) => - manifest.api.map((m) => ({ - ...m, + manifest.api.map((m) => { + const pluginType = manifest.type ? `${PLUGIN_SCHEMA_SEPARATOR + manifest.type}` : ''; + // 将插件的 identifier 作为前缀,避免重复 - name: manifest.identifier + PLUGIN_SCHEMA_SEPARATOR + m.name, - })), + const apiName = manifest.identifier + PLUGIN_SCHEMA_SEPARATOR + m.name + pluginType; + + return { + ...m, + name: apiName, + }; + }), ); return uniqBy(list, 'name'); diff --git a/src/store/session/slices/chat/actions/index.ts b/src/store/session/slices/chat/actions/index.ts index 9dfa9eacd6523..47e1bb093f3bb 100644 --- a/src/store/session/slices/chat/actions/index.ts +++ b/src/store/session/slices/chat/actions/index.ts @@ -3,6 +3,7 @@ import { StateCreator } from 'zustand/vanilla'; import { SessionStore } from '@/store/session'; import { ChatMessageAction, chatMessage } from './message'; +import { ChatPluginAction, chatPlugin } from './plugin'; import { ShareAction, chatShare } from './share'; import { ChatTopicAction, chatTopic } from './topic'; import { ChatTranslateAction, chatTranslate } from './translate'; @@ -14,7 +15,8 @@ export interface ChatAction extends ChatTopicAction, ChatMessageAction, ShareAction, - ChatTranslateAction {} + ChatTranslateAction, + ChatPluginAction {} export const createChatSlice: StateCreator< SessionStore, @@ -26,4 +28,5 @@ export const createChatSlice: StateCreator< ...chatMessage(...params), ...chatShare(...params), ...chatTranslate(...params), + ...chatPlugin(...params), }); diff --git a/src/store/session/slices/chat/actions/message.ts b/src/store/session/slices/chat/actions/message.ts index 2e3a777dc3557..72fbbf78fae2b 100644 --- a/src/store/session/slices/chat/actions/message.ts +++ b/src/store/session/slices/chat/actions/message.ts @@ -1,15 +1,12 @@ -import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk'; import { template } from 'lodash-es'; import { StateCreator } from 'zustand/vanilla'; import { LOADING_FLAT } from '@/const/message'; -import { PLUGIN_SCHEMA_SEPARATOR } from '@/const/plugin'; import { fetchChatModel } from '@/services/chatModel'; -import { fetchPlugin } from '@/services/plugin'; import { SessionStore } from '@/store/session'; -import { ChatMessage, OpenAIFunctionCall } from '@/types/chatMessage'; +import { ChatMessage } from '@/types/chatMessage'; import { fetchSSE } from '@/utils/fetch'; -import { isFunctionMessage } from '@/utils/message'; +import { isFunctionMessageAtStart, testFunctionMessageAtEnd } from '@/utils/message'; import { setNamespace } from '@/utils/storeDebug'; import { nanoid } from '@/utils/uuid'; @@ -29,6 +26,12 @@ export interface ChatMessageAction { * 清除消息 */ clearMessage: () => void; + /** + * 处理 ai 消息的核心逻辑(包含前处理与后处理) + * @param messages - 聊天消息数组 + * @param parentId - 父消息 ID,可选 + */ + coreProcessMessage: (messages: ChatMessage[], parentId: string) => Promise; /** * 删除消息 * @param id - 消息 ID @@ -40,21 +43,19 @@ export interface ChatMessageAction { */ dispatchMessage: (payload: MessageDispatch) => void; /** - * 生成消息 + * 实际获取 AI 响应 * @param messages - 聊天消息数组 * @param options - 获取 SSE 选项 */ - generateMessage: ( + fetchAIChatMessage: ( messages: ChatMessage[], assistantMessageId: string, - ) => Promise<{ isFunctionCall: boolean }>; - /** - * 实际获取 AI 响应 - * - * @param messages - 聊天消息数组 - * @param parentId - 父消息 ID,可选 - */ - realFetchAIResponse: (messages: ChatMessage[], parentId: string) => Promise; + ) => Promise<{ + content: string; + functionCallAtEnd: boolean; + functionCallContent: string; + isFunctionCall: boolean; + }>; /** * 重新发送消息 @@ -73,8 +74,6 @@ export interface ChatMessageAction { id?: string, action?: string, ) => AbortController | undefined; - - triggerFunctionCall: (id: string) => Promise; } export const chatMessage: StateCreator< @@ -93,10 +92,74 @@ export const chatMessage: StateCreator< } }, + coreProcessMessage: async (messages, userMessageId) => { + const { dispatchMessage, fetchAIChatMessage, triggerFunctionCall, activeTopicId } = get(); + + const { model } = agentSelectors.currentAgentConfig(get()); + + // 添加一个空的信息用于放置 ai 响应,注意顺序不能反 + // 因为如果顺序反了,messages 中将包含新增的 ai message + const mid = nanoid(); + + dispatchMessage({ + id: mid, + message: LOADING_FLAT, + parentId: userMessageId, + role: 'assistant', + type: 'addMessage', + }); + + // 如果有 activeTopicId,则添加 topicId + if (activeTopicId) { + dispatchMessage({ id: mid, key: 'topicId', type: 'updateMessage', value: activeTopicId }); + } + + // 为模型添加 fromModel 的额外信息 + dispatchMessage({ id: mid, key: 'fromModel', type: 'updateMessageExtra', value: model }); + + // 生成 ai message + const { isFunctionCall, content, functionCallAtEnd, functionCallContent } = + await fetchAIChatMessage(messages, mid); + + // 如果是 function,则发送函数调用方法 + if (isFunctionCall) { + let functionId = mid; + + if (functionCallAtEnd) { + // create a new separate message and remove the function call from the prev message + dispatchMessage({ + id: mid, + key: 'content', + type: 'updateMessage', + value: content.replace(functionCallContent, ''), + }); + + functionId = nanoid(); + dispatchMessage({ + id: functionId, + message: functionCallContent, + parentId: userMessageId, + role: 'assistant', + type: 'addMessage', + }); + + // also add activeTopicId + if (activeTopicId) + dispatchMessage({ + id: functionId, + key: 'topicId', + type: 'updateMessage', + value: activeTopicId, + }); + } + + triggerFunctionCall(functionId); + } + }, + deleteMessage: (id) => { get().dispatchMessage({ id, type: 'deleteMessage' }); }, - dispatchMessage: (payload) => { const { activeId } = get(); const session = sessionSelectors.currentSession(get()); @@ -106,7 +169,8 @@ export const chatMessage: StateCreator< get().dispatchSession({ chats, id: activeId, type: 'updateSessionChat' }); }, - generateMessage: async (messages, assistantId) => { + + fetchAIChatMessage: async (messages, assistantId) => { const { dispatchMessage, toggleChatLoading } = get(); const abortController = toggleChatLoading( @@ -160,6 +224,8 @@ export const chatMessage: StateCreator< let output = ''; let isFunctionCall = false; + let functionCallAtEnd = false; + let functionCallContent = ''; await fetchSSE(fetcher, { onErrorHandle: (error) => { @@ -168,15 +234,10 @@ export const chatMessage: StateCreator< onMessageHandle: (text) => { output += text; - dispatchMessage({ - id: assistantId, - key: 'content', - type: 'updateMessage', - value: output, - }); + dispatchMessage({ id: assistantId, key: 'content', type: 'updateMessage', value: output }); - // 如果是 function call - if (isFunctionMessage(output)) { + // is this message is just a function call + if (isFunctionMessageAtStart(output)) { isFunctionCall = true; } }, @@ -184,41 +245,19 @@ export const chatMessage: StateCreator< toggleChatLoading(false, undefined, t('generateMessage(end)') as string); - return { isFunctionCall }; - }, - - realFetchAIResponse: async (messages, userMessageId) => { - const { dispatchMessage, generateMessage, triggerFunctionCall, activeTopicId } = get(); - - const { model } = agentSelectors.currentAgentConfig(get()); + // also exist message like this: 请稍等,我帮您查询一下。{"function_call": {"name": "plugin-identifier____recommendClothes____standalone", "arguments": "{\n "mood": "",\n "gender": "man"\n}"}} + if (!isFunctionCall) { + const { content, valid } = testFunctionMessageAtEnd(output); - // 添加一个空的信息用于放置 ai 响应,注意顺序不能反 - // 因为如果顺序反了,messages 中将包含新增的 ai message - const mid = nanoid(); - - dispatchMessage({ - id: mid, - message: LOADING_FLAT, - parentId: userMessageId, - role: 'assistant', - type: 'addMessage', - }); - - // 如果有 activeTopicId,则添加 topicId - if (activeTopicId) { - dispatchMessage({ id: mid, key: 'topicId', type: 'updateMessage', value: activeTopicId }); + // if fc at end, replace the message + if (valid) { + isFunctionCall = true; + functionCallAtEnd = true; + functionCallContent = content; + } } - // 为模型添加 fromModel 的额外信息 - dispatchMessage({ id: mid, key: 'fromModel', type: 'updateMessageExtra', value: model }); - - // 生成 ai message - const { isFunctionCall } = await generateMessage(messages, mid); - - // 如果是 function,则发送函数调用方法 - if (isFunctionCall) { - triggerFunctionCall(mid); - } + return { content: output, functionCallAtEnd, functionCallContent, isFunctionCall }; }, resendMessage: async (messageId) => { @@ -254,17 +293,17 @@ export const chatMessage: StateCreator< if (contextMessages.length <= 0) return; - const { realFetchAIResponse } = get(); + const { coreProcessMessage } = get(); const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1); if (!latestMsg) return; - await realFetchAIResponse(contextMessages, latestMsg.id); + await coreProcessMessage(contextMessages, latestMsg.id); }, sendMessage: async (message) => { - const { dispatchMessage, realFetchAIResponse, activeTopicId } = get(); + const { dispatchMessage, coreProcessMessage, activeTopicId } = get(); const session = sessionSelectors.currentSession(get()); if (!session || !message) return; @@ -279,7 +318,7 @@ export const chatMessage: StateCreator< // Get the current messages to generate AI response const messages = chatSelectors.currentChats(get()); - await realFetchAIResponse(messages, userId); + await coreProcessMessage(messages, userId); // check activeTopic and then auto create topic const chats = chatSelectors.currentChats(get()); @@ -308,58 +347,4 @@ export const chatMessage: StateCreator< set({ abortController: undefined, chatLoadingId: undefined }, false, action); } }, - - triggerFunctionCall: async (id) => { - const { dispatchMessage, realFetchAIResponse, toggleChatLoading } = get(); - const session = sessionSelectors.currentSession(get()); - - if (!session) return; - - const message = session.chats[id]; - if (!message) return; - - let payload: PluginRequestPayload = { apiName: '', identifier: '' }; - // 识别到内容是 function_call 的情况下 - // 将 function_call 转换为 plugin request payload - if (message.content) { - const { function_call } = JSON.parse(message.content) as { - function_call: OpenAIFunctionCall; - }; - - const [identifier, apiName] = function_call.name.split(PLUGIN_SCHEMA_SEPARATOR); - payload = { apiName, arguments: function_call.arguments, identifier }; - - dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload }); - dispatchMessage({ id, key: 'content', type: 'updateMessage', value: '' }); - } else { - if (message.plugin) { - payload = message.plugin; - } - } - - if (!payload.apiName) return; - - dispatchMessage({ id, key: 'role', type: 'updateMessage', value: 'function' }); - dispatchMessage({ id, key: 'name', type: 'updateMessage', value: payload.identifier }); - dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload }); - - let data: string; - try { - const abortController = toggleChatLoading(true, id); - data = await fetchPlugin(payload, { signal: abortController?.signal }); - } catch (error) { - dispatchMessage({ id, key: 'error', type: 'updateMessage', value: error }); - - data = ''; - } - toggleChatLoading(false); - // 如果报错则结束了 - if (!data) return; - - dispatchMessage({ id, key: 'content', type: 'updateMessage', value: data }); - - const chats = chatSelectors.currentChats(get()); - - await realFetchAIResponse(chats, message.id); - }, }); diff --git a/src/store/session/slices/chat/actions/plugin.ts b/src/store/session/slices/chat/actions/plugin.ts new file mode 100644 index 0000000000000..eb1fbe2da9865 --- /dev/null +++ b/src/store/session/slices/chat/actions/plugin.ts @@ -0,0 +1,87 @@ +import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk'; +import { StateCreator } from 'zustand/vanilla'; + +import { PLUGIN_SCHEMA_SEPARATOR } from '@/const/plugin'; +import { fetchPlugin } from '@/services/plugin'; +import { SessionStore } from '@/store/session'; +import { OpenAIFunctionCall } from '@/types/chatMessage'; +import { setNamespace } from '@/utils/storeDebug'; + +import { sessionSelectors } from '../../session/selectors'; +import { chatSelectors } from '../selectors'; + +const t = setNamespace('chat/plugin'); + +/** + * 插件方法 + */ +export interface ChatPluginAction { + runPluginAutoMode: (id: string, payload: any) => Promise; + triggerFunctionCall: (id: string) => Promise; +} + +export const chatPlugin: StateCreator< + SessionStore, + [['zustand/devtools', never]], + [], + ChatPluginAction +> = (set, get) => ({ + runPluginAutoMode: async (id, payload) => { + const { dispatchMessage, coreProcessMessage, toggleChatLoading } = get(); + let data: string; + try { + const abortController = toggleChatLoading(true, id, t('fetchPlugin') as string); + data = await fetchPlugin(payload, { signal: abortController?.signal }); + } catch (error) { + dispatchMessage({ id, key: 'error', type: 'updateMessage', value: error }); + + data = ''; + } + toggleChatLoading(false); + // 如果报错则结束了 + if (!data) return; + + dispatchMessage({ id, key: 'content', type: 'updateMessage', value: data }); + + const chats = chatSelectors.currentChats(get()); + + await coreProcessMessage(chats, id); + }, + + triggerFunctionCall: async (id) => { + const { dispatchMessage, runPluginAutoMode } = get(); + const session = sessionSelectors.currentSession(get()); + + if (!session) return; + + const message = session.chats[id]; + if (!message) return; + + let payload: PluginRequestPayload = { apiName: '', identifier: '' }; + // 识别到内容是 function_call 的情况下 + // 将 function_call 转换为 plugin request payload + if (message.content) { + const { function_call } = JSON.parse(message.content) as { + function_call: OpenAIFunctionCall; + }; + + const [identifier, apiName] = function_call.name.split(PLUGIN_SCHEMA_SEPARATOR); + payload = { apiName, arguments: function_call.arguments, identifier }; + + dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload }); + dispatchMessage({ id, key: 'content', type: 'updateMessage', value: '' }); + } else { + if (message.plugin) { + payload = message.plugin; + } + } + + if (!payload.apiName) return; + + dispatchMessage({ id, key: 'role', type: 'updateMessage', value: 'function' }); + dispatchMessage({ id, key: 'name', type: 'updateMessage', value: payload.identifier }); + dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload }); + + runPluginAutoMode(id, payload); + }, +}); diff --git a/src/utils/message.ts b/src/utils/message.ts index 8c338a6a582a3..9295de007f660 100644 --- a/src/utils/message.ts +++ b/src/utils/message.ts @@ -1,9 +1,16 @@ import { FUNCTION_MESSAGE_FLAG } from '@/const/message'; -export const isFunctionMessage = (content: string) => { +export const isFunctionMessageAtStart = (content: string) => { return content.startsWith(FUNCTION_MESSAGE_FLAG); }; +export const testFunctionMessageAtEnd = (content: string) => { + const regExp = /{"function_call":.*?}}/; + const match = content.match(regExp); + + return { content: match ? match[0] : '', valid: match }; +}; + // export const createFunctionCallMessage = () => { // return [ // {