Skip to content

Commit

Permalink
✨ feat: support function call at message end
Browse files Browse the repository at this point in the history
  • Loading branch information
arvinxx committed Oct 22, 2023
1 parent 94cccb2 commit c82a3dc
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { memo } from 'react';

import { useSessionStore } from '@/store/session';
import { chatSelectors } from '@/store/session/selectors';
import { isFunctionMessage } from '@/utils/message';
import { isFunctionMessageAtStart } from '@/utils/message';

import FunctionCall from '../Plugins/FunctionCall';
import { DefaultMessage } from './Default';
Expand All @@ -12,7 +12,8 @@ export const AssistantMessage: RenderMessage = memo(
({ id, plugin, function_call, content, ...props }) => {
const genFunctionCallProps = useSessionStore(chatSelectors.getFunctionMessageParams);

if (!isFunctionMessage(content)) return <DefaultMessage content={content} id={id} {...props} />;
if (!isFunctionMessageAtStart(content))
return <DefaultMessage content={content} id={id} {...props} />;

const fcProps = genFunctionCallProps({ content, function_call, id, plugin });

Expand Down
2 changes: 1 addition & 1 deletion src/const/plugin.ts
Original file line number Diff line number Diff line change
@@ -1 +1 @@
export const PLUGIN_SCHEMA_SEPARATOR = '--__--';
export const PLUGIN_SCHEMA_SEPARATOR = '____';
14 changes: 10 additions & 4 deletions src/store/plugin/selectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ const enabledSchema =
return enabledPlugins.includes(p.identifier);
})
.flatMap((manifest) =>
manifest.api.map((m) => ({
...m,
manifest.api.map((m) => {
const pluginType = manifest.type ? `${PLUGIN_SCHEMA_SEPARATOR + manifest.type}` : '';

// 将插件的 identifier 作为前缀,避免重复
name: manifest.identifier + PLUGIN_SCHEMA_SEPARATOR + m.name,
})),
const apiName = manifest.identifier + PLUGIN_SCHEMA_SEPARATOR + m.name + pluginType;

return {
...m,
name: apiName,
};
}),
);

return uniqBy(list, 'name');
Expand Down
5 changes: 4 additions & 1 deletion src/store/session/slices/chat/actions/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { StateCreator } from 'zustand/vanilla';
import { SessionStore } from '@/store/session';

import { ChatMessageAction, chatMessage } from './message';
import { ChatPluginAction, chatPlugin } from './plugin';
import { ShareAction, chatShare } from './share';
import { ChatTopicAction, chatTopic } from './topic';
import { ChatTranslateAction, chatTranslate } from './translate';
Expand All @@ -14,7 +15,8 @@ export interface ChatAction
extends ChatTopicAction,
ChatMessageAction,
ShareAction,
ChatTranslateAction {}
ChatTranslateAction,
ChatPluginAction {}

export const createChatSlice: StateCreator<
SessionStore,
Expand All @@ -26,4 +28,5 @@ export const createChatSlice: StateCreator<
...chatMessage(...params),
...chatShare(...params),
...chatTranslate(...params),
...chatPlugin(...params),
});
219 changes: 102 additions & 117 deletions src/store/session/slices/chat/actions/message.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk';
import { template } from 'lodash-es';
import { StateCreator } from 'zustand/vanilla';

import { LOADING_FLAT } from '@/const/message';
import { PLUGIN_SCHEMA_SEPARATOR } from '@/const/plugin';
import { fetchChatModel } from '@/services/chatModel';
import { fetchPlugin } from '@/services/plugin';
import { SessionStore } from '@/store/session';
import { ChatMessage, OpenAIFunctionCall } from '@/types/chatMessage';
import { ChatMessage } from '@/types/chatMessage';
import { fetchSSE } from '@/utils/fetch';
import { isFunctionMessage } from '@/utils/message';
import { isFunctionMessageAtStart, testFunctionMessageAtEnd } from '@/utils/message';
import { setNamespace } from '@/utils/storeDebug';
import { nanoid } from '@/utils/uuid';

Expand All @@ -29,6 +26,12 @@ export interface ChatMessageAction {
* 清除消息
*/
clearMessage: () => void;
/**
* 处理 ai 消息的核心逻辑(包含前处理与后处理)
* @param messages - 聊天消息数组
* @param parentId - 父消息 ID,可选
*/
coreProcessMessage: (messages: ChatMessage[], parentId: string) => Promise<void>;
/**
* 删除消息
* @param id - 消息 ID
Expand All @@ -40,21 +43,19 @@ export interface ChatMessageAction {
*/
dispatchMessage: (payload: MessageDispatch) => void;
/**
* 生成消息
* 实际获取 AI 响应
* @param messages - 聊天消息数组
* @param options - 获取 SSE 选项
*/
generateMessage: (
fetchAIChatMessage: (
messages: ChatMessage[],
assistantMessageId: string,
) => Promise<{ isFunctionCall: boolean }>;
/**
* 实际获取 AI 响应
*
* @param messages - 聊天消息数组
* @param parentId - 父消息 ID,可选
*/
realFetchAIResponse: (messages: ChatMessage[], parentId: string) => Promise<void>;
) => Promise<{
content: string;
functionCallAtEnd: boolean;
functionCallContent: string;
isFunctionCall: boolean;
}>;

/**
* 重新发送消息
Expand All @@ -73,8 +74,6 @@ export interface ChatMessageAction {
id?: string,
action?: string,
) => AbortController | undefined;

triggerFunctionCall: (id: string) => Promise<void>;
}

export const chatMessage: StateCreator<
Expand All @@ -93,10 +92,74 @@ export const chatMessage: StateCreator<
}
},

coreProcessMessage: async (messages, userMessageId) => {
const { dispatchMessage, fetchAIChatMessage, triggerFunctionCall, activeTopicId } = get();

const { model } = agentSelectors.currentAgentConfig(get());

// 添加一个空的信息用于放置 ai 响应,注意顺序不能反
// 因为如果顺序反了,messages 中将包含新增的 ai message
const mid = nanoid();

dispatchMessage({
id: mid,
message: LOADING_FLAT,
parentId: userMessageId,
role: 'assistant',
type: 'addMessage',
});

// 如果有 activeTopicId,则添加 topicId
if (activeTopicId) {
dispatchMessage({ id: mid, key: 'topicId', type: 'updateMessage', value: activeTopicId });
}

// 为模型添加 fromModel 的额外信息
dispatchMessage({ id: mid, key: 'fromModel', type: 'updateMessageExtra', value: model });

// 生成 ai message
const { isFunctionCall, content, functionCallAtEnd, functionCallContent } =
await fetchAIChatMessage(messages, mid);

// 如果是 function,则发送函数调用方法
if (isFunctionCall) {
let functionId = mid;

if (functionCallAtEnd) {
// create a new separate message and remove the function call from the prev message
dispatchMessage({
id: mid,
key: 'content',
type: 'updateMessage',
value: content.replace(functionCallContent, ''),
});

functionId = nanoid();
dispatchMessage({
id: functionId,
message: functionCallContent,
parentId: userMessageId,
role: 'assistant',
type: 'addMessage',
});

// also add activeTopicId
if (activeTopicId)
dispatchMessage({
id: functionId,
key: 'topicId',
type: 'updateMessage',
value: activeTopicId,
});
}

triggerFunctionCall(functionId);
}
},

deleteMessage: (id) => {
get().dispatchMessage({ id, type: 'deleteMessage' });
},

dispatchMessage: (payload) => {
const { activeId } = get();
const session = sessionSelectors.currentSession(get());
Expand All @@ -106,7 +169,8 @@ export const chatMessage: StateCreator<

get().dispatchSession({ chats, id: activeId, type: 'updateSessionChat' });
},
generateMessage: async (messages, assistantId) => {

fetchAIChatMessage: async (messages, assistantId) => {
const { dispatchMessage, toggleChatLoading } = get();

const abortController = toggleChatLoading(
Expand Down Expand Up @@ -160,6 +224,8 @@ export const chatMessage: StateCreator<

let output = '';
let isFunctionCall = false;
let functionCallAtEnd = false;
let functionCallContent = '';

await fetchSSE(fetcher, {
onErrorHandle: (error) => {
Expand All @@ -168,57 +234,30 @@ export const chatMessage: StateCreator<
onMessageHandle: (text) => {
output += text;

dispatchMessage({
id: assistantId,
key: 'content',
type: 'updateMessage',
value: output,
});
dispatchMessage({ id: assistantId, key: 'content', type: 'updateMessage', value: output });

// 如果是 function call
if (isFunctionMessage(output)) {
// is this message is just a function call
if (isFunctionMessageAtStart(output)) {
isFunctionCall = true;
}
},
});

toggleChatLoading(false, undefined, t('generateMessage(end)') as string);

return { isFunctionCall };
},

realFetchAIResponse: async (messages, userMessageId) => {
const { dispatchMessage, generateMessage, triggerFunctionCall, activeTopicId } = get();

const { model } = agentSelectors.currentAgentConfig(get());
// also exist message like this: 请稍等,我帮您查询一下。{"function_call": {"name": "plugin-identifier____recommendClothes____standalone", "arguments": "{\n "mood": "",\n "gender": "man"\n}"}}
if (!isFunctionCall) {
const { content, valid } = testFunctionMessageAtEnd(output);

// 添加一个空的信息用于放置 ai 响应,注意顺序不能反
// 因为如果顺序反了,messages 中将包含新增的 ai message
const mid = nanoid();

dispatchMessage({
id: mid,
message: LOADING_FLAT,
parentId: userMessageId,
role: 'assistant',
type: 'addMessage',
});

// 如果有 activeTopicId,则添加 topicId
if (activeTopicId) {
dispatchMessage({ id: mid, key: 'topicId', type: 'updateMessage', value: activeTopicId });
// if fc at end, replace the message
if (valid) {
isFunctionCall = true;
functionCallAtEnd = true;
functionCallContent = content;
}
}

// 为模型添加 fromModel 的额外信息
dispatchMessage({ id: mid, key: 'fromModel', type: 'updateMessageExtra', value: model });

// 生成 ai message
const { isFunctionCall } = await generateMessage(messages, mid);

// 如果是 function,则发送函数调用方法
if (isFunctionCall) {
triggerFunctionCall(mid);
}
return { content: output, functionCallAtEnd, functionCallContent, isFunctionCall };
},

resendMessage: async (messageId) => {
Expand Down Expand Up @@ -254,17 +293,17 @@ export const chatMessage: StateCreator<

if (contextMessages.length <= 0) return;

const { realFetchAIResponse } = get();
const { coreProcessMessage } = get();

const latestMsg = contextMessages.filter((s) => s.role === 'user').at(-1);

if (!latestMsg) return;

await realFetchAIResponse(contextMessages, latestMsg.id);
await coreProcessMessage(contextMessages, latestMsg.id);
},

sendMessage: async (message) => {
const { dispatchMessage, realFetchAIResponse, activeTopicId } = get();
const { dispatchMessage, coreProcessMessage, activeTopicId } = get();
const session = sessionSelectors.currentSession(get());
if (!session || !message) return;

Expand All @@ -279,7 +318,7 @@ export const chatMessage: StateCreator<
// Get the current messages to generate AI response
const messages = chatSelectors.currentChats(get());

await realFetchAIResponse(messages, userId);
await coreProcessMessage(messages, userId);

// check activeTopic and then auto create topic
const chats = chatSelectors.currentChats(get());
Expand Down Expand Up @@ -308,58 +347,4 @@ export const chatMessage: StateCreator<
set({ abortController: undefined, chatLoadingId: undefined }, false, action);
}
},

triggerFunctionCall: async (id) => {
const { dispatchMessage, realFetchAIResponse, toggleChatLoading } = get();
const session = sessionSelectors.currentSession(get());

if (!session) return;

const message = session.chats[id];
if (!message) return;

let payload: PluginRequestPayload = { apiName: '', identifier: '' };
// 识别到内容是 function_call 的情况下
// 将 function_call 转换为 plugin request payload
if (message.content) {
const { function_call } = JSON.parse(message.content) as {
function_call: OpenAIFunctionCall;
};

const [identifier, apiName] = function_call.name.split(PLUGIN_SCHEMA_SEPARATOR);
payload = { apiName, arguments: function_call.arguments, identifier };

dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload });
dispatchMessage({ id, key: 'content', type: 'updateMessage', value: '' });
} else {
if (message.plugin) {
payload = message.plugin;
}
}

if (!payload.apiName) return;

dispatchMessage({ id, key: 'role', type: 'updateMessage', value: 'function' });
dispatchMessage({ id, key: 'name', type: 'updateMessage', value: payload.identifier });
dispatchMessage({ id, key: 'plugin', type: 'updateMessage', value: payload });

let data: string;
try {
const abortController = toggleChatLoading(true, id);
data = await fetchPlugin(payload, { signal: abortController?.signal });
} catch (error) {
dispatchMessage({ id, key: 'error', type: 'updateMessage', value: error });

data = '';
}
toggleChatLoading(false);
// 如果报错则结束了
if (!data) return;

dispatchMessage({ id, key: 'content', type: 'updateMessage', value: data });

const chats = chatSelectors.currentChats(get());

await realFetchAIResponse(chats, message.id);
},
});
Loading

0 comments on commit c82a3dc

Please sign in to comment.