From c1e6eb2a0e3dfdf4a8d80f42edcda6ff326e5673 Mon Sep 17 00:00:00 2001 From: Ryan Hopper-Lowe <46546486+ryanhopperlowe@users.noreply.github.com> Date: Wed, 29 Jan 2025 13:36:19 -0600 Subject: [PATCH] chore: refactor message store to use zustand (better React State Managemnt) (#1495) - reduces complexity by adding complexity Signed-off-by: Ryan Hopper-Lowe --- .../agent/shared/ToolAuthenticationDialog.tsx | 4 +- ui/admin/app/components/chat/ChatContext.tsx | 10 +- .../app/hooks/messages/useMessageStore.ts | 20 ++ .../app/hooks/messages/useThreadEvents.ts | 170 ---------------- ui/admin/app/lib/store/chat/message-store.ts | 191 ++++++++++++++++++ ui/admin/package.json | 3 +- ui/admin/pnpm-lock.yaml | 27 +++ 7 files changed, 246 insertions(+), 179 deletions(-) create mode 100644 ui/admin/app/hooks/messages/useMessageStore.ts delete mode 100644 ui/admin/app/hooks/messages/useThreadEvents.ts create mode 100644 ui/admin/app/lib/store/chat/message-store.ts diff --git a/ui/admin/app/components/agent/shared/ToolAuthenticationDialog.tsx b/ui/admin/app/components/agent/shared/ToolAuthenticationDialog.tsx index a26ec0847..110ed1d8a 100644 --- a/ui/admin/app/components/agent/shared/ToolAuthenticationDialog.tsx +++ b/ui/admin/app/components/agent/shared/ToolAuthenticationDialog.tsx @@ -15,7 +15,7 @@ import { DialogTitle, } from "~/components/ui/dialog"; import { Link } from "~/components/ui/link"; -import { useThreadEvents } from "~/hooks/messages/useThreadEvents"; +import { useInitMessageStore } from "~/hooks/messages/useMessageStore"; type AgentAuthenticationDialogProps = { threadId: Nullish; @@ -30,7 +30,7 @@ export function ToolAuthenticationDialog({ }: AgentAuthenticationDialogProps) { const { icon, label } = useToolReference(tool); - const { messages: _messages } = useThreadEvents(threadId); + const { messages: _messages } = useInitMessageStore(threadId); type ItemState = { isLoading?: boolean; diff --git a/ui/admin/app/components/chat/ChatContext.tsx b/ui/admin/app/components/chat/ChatContext.tsx index 96b56b5f3..c0f2434e9 100644 --- a/ui/admin/app/components/chat/ChatContext.tsx +++ b/ui/admin/app/components/chat/ChatContext.tsx @@ -2,17 +2,16 @@ import { ReactNode, createContext, useContext } from "react"; import { mutate } from "swr"; import { AgentIcons } from "~/lib/model/agents"; -import { Message } from "~/lib/model/messages"; import { InvokeService } from "~/lib/service/api/invokeService"; import { ThreadsService } from "~/lib/service/api/threadsService"; +import { MessageStore } from "~/lib/store/chat/message-store"; -import { useThreadEvents } from "~/hooks/messages/useThreadEvents"; +import { useInitMessageStore } from "~/hooks/messages/useMessageStore"; import { useAsync } from "~/hooks/useAsync"; type Mode = "agent" | "workflow"; -interface ChatContextType { - messages: Message[]; +interface ChatContextType extends Pick { mode: Mode; processUserMessage: (text: string) => void; abortRunningThread: () => void; @@ -20,7 +19,6 @@ interface ChatContextType { threadId: Nullish; invoke: (prompt?: string) => void; readOnly?: boolean; - isRunning: boolean; isInvoking: boolean; introductionMessage?: string; starterMessages?: string[]; @@ -73,7 +71,7 @@ export function ChatProvider({ }, }); - const { messages, isRunning } = useThreadEvents(threadId); + const { messages, isRunning } = useInitMessageStore(threadId); const abortRunningThread = () => { if (!threadId || !isRunning) return; diff --git a/ui/admin/app/hooks/messages/useMessageStore.ts b/ui/admin/app/hooks/messages/useMessageStore.ts new file mode 100644 index 000000000..c01e02687 --- /dev/null +++ b/ui/admin/app/hooks/messages/useMessageStore.ts @@ -0,0 +1,20 @@ +import { useEffect, useState } from "react"; +import { useStore } from "zustand"; + +import { createMessageStore } from "~/lib/store/chat/message-store"; + +export const useInitMessageStore = (threadId: Nullish) => { + const [storeObj] = useState(() => createMessageStore()); + const store = useStore(storeObj); + + const { init, reset } = store; + useEffect(() => { + if (!threadId) return; + + init(threadId); + + return () => reset(); + }, [init, reset, threadId]); + + return store; +}; diff --git a/ui/admin/app/hooks/messages/useThreadEvents.ts b/ui/admin/app/hooks/messages/useThreadEvents.ts deleted file mode 100644 index 47bb6f2bc..000000000 --- a/ui/admin/app/hooks/messages/useThreadEvents.ts +++ /dev/null @@ -1,170 +0,0 @@ -import { useCallback, useEffect, useState } from "react"; - -import { ChatEvent } from "~/lib/model/chatEvents"; -import { Message, promptMessage, toolCallMessage } from "~/lib/model/messages"; -import { ThreadsService } from "~/lib/service/api/threadsService"; - -export function useThreadEvents(threadId?: Nullish) { - const [messages, setMessages] = useState([]); - const [isRunning, setIsRunning] = useState(false); - - const addContent = useCallback((event: ChatEvent) => { - const { - content, - prompt, - toolCall, - runComplete, - input, - error, - runID, - contentID, - replayComplete, - time, - } = event; - - setIsRunning(!runComplete && !replayComplete); - - setMessages((prev) => { - const copy = [...prev]; - - // todo(ryanhopperlowe) can be optmized by searching from the end - const existingIndex = contentID - ? copy.findIndex((m) => m.contentID === contentID) - : -1; - - if (existingIndex !== -1) { - const existing = copy[existingIndex]; - copy[existingIndex] = { - ...existing, - text: existing.text + content, - time: existing.time || time, - }; - - return copy; - } - - if (error) { - if (error.includes("thread was aborted, cancelling run")) { - copy.push({ - sender: "agent", - text: "Message Aborted", - runId: runID, - contentID, - aborted: true, - time, - }); - - return copy; - } - - copy.push({ - sender: "agent", - text: error, - runId: runID, - error: true, - contentID, - time, - }); - return copy; - } - - if (input) { - copy.push({ - sender: "user", - text: input, - runId: runID, - contentID, - time, - }); - return copy; - } - - if (toolCall) { - return handleToolCallEvent(copy, event); - } - - if (prompt) { - copy.push(promptMessage(prompt, runID)); - return copy; - } - - if (content) { - copy.push({ - sender: "agent", - text: content, - runId: runID, - contentID, - time, - }); - return copy; - } - - return copy; - }); - }, []); - - useEffect(() => { - setMessages([]); - - if (!threadId) return; - - const source = ThreadsService.getThreadEventSource(threadId); - - let replayComplete = false; - let replayMessages: ChatEvent[] = []; - - source.addEventListener("close", source.close); - - source.addEventListener("message", (chunk) => { - const event = JSON.parse(chunk.data) as ChatEvent; - - if (event.replayComplete) { - replayComplete = true; - replayMessages.forEach(addContent); - replayMessages = []; - } - - if (!replayComplete) { - replayMessages.push(event); - return; - } - - addContent(event); - }); - - return () => { - source.close(); - setIsRunning(false); - }; - }, [threadId, addContent]); - - return { messages, isRunning }; -} - -const findIndexLastPendingToolCall = (messages: Message[]) => { - for (let i = messages.length - 1; i >= 0; i--) { - const message = messages[i]; - if (message.tools && !message.tools[0].output) { - return i; - } - } - return null; -}; - -const handleToolCallEvent = (messages: Message[], event: ChatEvent) => { - if (!event.toolCall) return messages; - - const { toolCall } = event; - if (toolCall.output) { - const index = findIndexLastPendingToolCall(messages); - if (index !== null) { - // update the found pending toolcall message (without output) - messages[index].tools = [toolCall]; - return messages; - } - } - - // otherwise add a new toolcall message - messages.push(toolCallMessage(toolCall)); - return messages; -}; diff --git a/ui/admin/app/lib/store/chat/message-store.ts b/ui/admin/app/lib/store/chat/message-store.ts new file mode 100644 index 000000000..367fb73b8 --- /dev/null +++ b/ui/admin/app/lib/store/chat/message-store.ts @@ -0,0 +1,191 @@ +import { createStore } from "zustand"; + +import { ChatEvent } from "~/lib/model/chatEvents"; +import { Message, promptMessage, toolCallMessage } from "~/lib/model/messages"; +import { ThreadsService } from "~/lib/service/api/threadsService"; + +export type MessageStore = { + messages: Message[]; + source: EventSource | null; + isRunning: boolean; + cleanupFns: (() => void)[]; + processEvent: (event: ChatEvent) => void; + init: (threadId: string) => void; + reset: () => void; +}; + +export const createMessageStore = () => { + return createStore()((set, get) => { + return { + messages: [], + cleanupFns: [], + source: null, + isRunning: false, + processEvent: handleProcessEvent, + init: handleInit, + reset: handleReset, + }; + + function handleInit(threadId: string) { + const source = ThreadsService.getThreadEventSource(threadId); + let replayComplete = false; + let replayMessages: ChatEvent[] = []; + + const handleClose = () => source.close(); + + const handleMessage = (chunk: MessageEvent): void => { + const event = JSON.parse(chunk.data) as ChatEvent; + + if (event.replayComplete) { + replayComplete = true; + replayMessages.forEach(get().processEvent); + replayMessages = []; + } + + if (!replayComplete) { + replayMessages.push(event); + return; + } + + get().processEvent(event); + }; + + source.addEventListener("close", handleClose); + source.addEventListener("message", handleMessage); + + const cleanupFns = get().cleanupFns.concat( + () => source.removeEventListener("close", handleClose), + () => source.removeEventListener("message", handleMessage) + ); + + set({ cleanupFns, source }); + } + + function handleReset() { + const { source, cleanupFns: listenerCleanupFns } = get(); + + listenerCleanupFns.forEach((fn) => fn()); + source?.close(); + + set({ + source: null, + isRunning: false, + messages: [], + cleanupFns: [], + }); + } + + function handleProcessEvent(event: ChatEvent) { + const { + content, + prompt, + toolCall, + runComplete, + input, + error, + runID, + contentID, + replayComplete, + time, + } = event; + + set({ isRunning: !runComplete && !replayComplete }); + + set((state) => { + const copy = [...state.messages]; + + const existingIndex = contentID + ? copy.findLastIndex((m) => m.contentID === contentID) + : -1; + + if (existingIndex !== -1) { + const existing = copy[existingIndex]; + copy[existingIndex] = { + ...existing, + text: existing.text + content, + time: existing.time || time, + }; + + return { messages: copy }; + } + + if (error) { + if (error.includes("thread was aborted, cancelling run")) { + copy.push({ + sender: "agent", + text: "Message Aborted", + runId: runID, + contentID, + aborted: true, + time, + }); + + return { messages: copy }; + } + + copy.push({ + sender: "agent", + text: error, + runId: runID, + error: true, + contentID, + time, + }); + return { messages: copy }; + } + + if (input) { + copy.push({ + sender: "user", + text: input, + runId: runID, + contentID, + time, + }); + return { messages: copy }; + } + + if (toolCall) { + return { messages: handleToolCallEvent(copy, event) }; + } + + if (prompt) { + copy.push(promptMessage(prompt, runID)); + return { messages: copy }; + } + + if (content) { + copy.push({ + sender: "agent", + text: content, + runId: runID, + contentID, + time, + }); + return { messages: copy }; + } + + return { messages: copy }; + }); + } + }); +}; + +const handleToolCallEvent = (messages: Message[], event: ChatEvent) => { + if (!event.toolCall) return messages; + + const { toolCall } = event; + if (toolCall.output) { + // const index = findIndexLastPendingToolCall(messages); + const index = messages.findLastIndex((m) => m.tools && !m.tools[0].output); + if (index !== -1) { + // update the found pending toolcall message (without output) + messages[index].tools = [toolCall]; + return messages; + } + } + + // otherwise add a new toolcall message + messages.push(toolCallMessage(toolCall)); + return messages; +}; diff --git a/ui/admin/package.json b/ui/admin/package.json index 72088f0ae..1814c1d66 100644 --- a/ui/admin/package.json +++ b/ui/admin/package.json @@ -72,7 +72,8 @@ "tailwind-merge": "^2.5.2", "tailwindcss-animate": "^1.0.7", "vaul": "^1.1.0", - "zod": "^3.23.8" + "zod": "^3.23.8", + "zustand": "^5.0.3" }, "devDependencies": { "@faker-js/faker": "^9.1.0", diff --git a/ui/admin/pnpm-lock.yaml b/ui/admin/pnpm-lock.yaml index ec6b8ee84..6814f6ee1 100644 --- a/ui/admin/pnpm-lock.yaml +++ b/ui/admin/pnpm-lock.yaml @@ -185,6 +185,9 @@ importers: zod: specifier: ^3.23.8 version: 3.24.1 + zustand: + specifier: ^5.0.3 + version: 5.0.3(@types/react@18.3.17)(react@18.3.1)(use-sync-external-store@1.4.0(react@18.3.1)) devDependencies: '@faker-js/faker': specifier: ^9.1.0 @@ -4009,6 +4012,24 @@ packages: zod@3.24.1: resolution: {integrity: sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==} + zustand@5.0.3: + resolution: {integrity: sha512-14fwWQtU3pH4dE0dOpdMiWjddcH+QzKIgk1cl8epwSE7yag43k/AD/m4L6+K7DytAOr9gGBe3/EXj9g7cdostg==} + engines: {node: '>=12.20.0'} + peerDependencies: + '@types/react': '>=18.0.0' + immer: '>=9.0.6' + react: '>=18.0.0' + use-sync-external-store: '>=1.2.0' + peerDependenciesMeta: + '@types/react': + optional: true + immer: + optional: true + react: + optional: true + use-sync-external-store: + optional: true + zwitch@2.0.4: resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} @@ -8328,4 +8349,10 @@ snapshots: zod@3.24.1: {} + zustand@5.0.3(@types/react@18.3.17)(react@18.3.1)(use-sync-external-store@1.4.0(react@18.3.1)): + optionalDependencies: + '@types/react': 18.3.17 + react: 18.3.1 + use-sync-external-store: 1.4.0(react@18.3.1) + zwitch@2.0.4: {}