Skip to content

Commit

Permalink
feat: attachment support in mo.ui.chat (#2493)
Browse files Browse the repository at this point in the history
* feat: attachment support in mo.ui.chat

* docs

* fixes

* typecheck

* Templated Prompts
  • Loading branch information
mscolnick authored and akshayka committed Oct 4, 2024
1 parent eab19a8 commit 585302f
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 48 deletions.
34 changes: 33 additions & 1 deletion docs/api/inputs/chat.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ chat.value
```

This returns a list of [`ChatMessage`](#marimo.ai.ChatMessage) objects, each
containing `role` and `content` attributes.
containing `role`, `content`, and optional `attachments` attributes.

```{eval-rst}
.. autoclass:: ChatMessage
Expand Down Expand Up @@ -111,6 +111,38 @@ mo.ui.chat(rag_model)
This example demonstrates how you can implement a Retrieval-Augmented
Generation (RAG) model within the chat interface.

## Templated Prompts

You can pass sample prompts to `mo.ui.chat` to allow users to select from a
list of predefined prompts. By including a `{{var}}` in the prompt, you can
dynamically insert values into the prompt; a form will be generated to allow
users to fill in the variables.

```python
mo.ui.chat(
mo.ai.llm.openai("gpt-4o"),
prompts=[
"What is the capital of France?",
"What is the capital of Germany?",
"What is the capital of {{country}}?",
],
)
```

## Including Attachments

You can allow users to upload attachments to their messages by passing an
`allow_attachments` parameter to `mo.ui.chat`.

```python
mo.ui.chat(
rag_model,
allow_attachments=["image/png", "image/jpeg"],
# or True for any attachment type
# allow_attachments=True,
)
```

## Built-in Models

marimo provides several built-in AI models that you can use with the chat UI
Expand Down
4 changes: 4 additions & 0 deletions examples/ai/chat/anthropic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def __(key, mo):
system_message="You are a helpful assistant.",
api_key=key,
),
allow_attachments=[
"image/png",
"image/jpeg"
],
prompts=[
"Hello",
"How are you?",
Expand Down
4 changes: 4 additions & 0 deletions examples/ai/chat/openai_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def __(mo, openai_key):
"gpt-4o",
system_message="You are a helpful assistant.",
api_key=openai_key,
allow_attachments=[
"image/png",
"image/jpeg"
],
),
prompts=[
"Hello",
Expand Down
11 changes: 11 additions & 0 deletions frontend/src/plugins/impl/chat/ChatPlugin.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export const ChatPlugin = createPlugin<ChatMessage[]>("marimo-chatbot")
frequencyPenalty: z.number().default(0),
presencePenalty: z.number().default(0),
}),
allowAttachments: z.union([z.boolean(), z.string().array()]),
}),
)
.withFunctions<PluginFunctions>({
Expand All @@ -46,6 +47,15 @@ export const ChatPlugin = createPlugin<ChatMessage[]>("marimo-chatbot")
z.object({
role: z.enum(["system", "user", "assistant"]),
content: z.string(),
attachments: z
.array(
z.object({
name: z.string().optional(),
contentType: z.string().optional(),
url: z.string(),
}),
)
.optional(),
}),
),
config: z.object({
Expand All @@ -65,6 +75,7 @@ export const ChatPlugin = createPlugin<ChatMessage[]>("marimo-chatbot")
<Chatbot
prompts={props.data.prompts}
showConfigurationControls={props.data.showConfigurationControls}
allowAttachments={props.data.allowAttachments}
config={props.data.config}
sendPrompt={props.functions.send_prompt}
value={props.value || Arrays.EMPTY}
Expand Down
165 changes: 154 additions & 11 deletions frontend/src/plugins/impl/chat/chat-ui.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@
import { Spinner } from "@/components/icons/spinner";
import { Logger } from "@/utils/Logger";
import { type Message, useChat } from "ai/react";
import React, { useEffect } from "react";
import type { ChatMessage, ChatConfig, SendMessageRequest } from "./types";
import React, { useEffect, useRef } from "react";
import type {
ChatMessage,
ChatConfig,
SendMessageRequest,
ChatAttachment,
ChatRole,
} from "./types";
import { ErrorBanner } from "../common/error-banner";
import { Button } from "@/components/ui/button";
import { Button, buttonVariants } from "@/components/ui/button";
import {
BotMessageSquareIcon,
ClipboardIcon,
HelpCircleIcon,
SendIcon,
Trash2Icon,
DownloadIcon,
PaperclipIcon,
X,
} from "lucide-react";
import { cn } from "@/utils/cn";
import { toast } from "@/components/ui/use-toast";
Expand Down Expand Up @@ -42,6 +51,7 @@ interface Props {
prompts: string[];
config: ChatConfig;
showConfigurationControls: boolean;
allowAttachments: boolean | string[];
sendPrompt(req: SendMessageRequest): Promise<string>;
value: ChatMessage[];
setValue: (messages: ChatMessage[]) => void;
Expand All @@ -50,6 +60,8 @@ interface Props {
export const Chatbot: React.FC<Props> = (props) => {
const inputRef = React.useRef<HTMLInputElement>(null);
const [config, setConfig] = useState<ChatConfig>(props.config);
const [files, setFiles] = useState<FileList | undefined>(undefined);
const fileInputRef = useRef<HTMLInputElement>(null);

const {
messages,
Expand All @@ -67,11 +79,15 @@ export const Chatbot: React.FC<Props> = (props) => {
streamProtocol: "text",
fetch: async (_url, request) => {
const body = JSON.parse(request?.body as string) as {
messages: ChatMessage[];
messages: Message[];
};
try {
const response = await props.sendPrompt({
...body,
messages: body.messages.map((m) => ({
role: m.role as ChatRole,
content: m.content,
attachments: m.experimental_attachments,
})),
config: {
max_tokens: config.maxTokens,
temperature: config.temperature,
Expand All @@ -92,6 +108,11 @@ export const Chatbot: React.FC<Props> = (props) => {
}
},
onFinish: (message, { usage, finishReason }) => {
setFiles(undefined);

if (fileInputRef.current) {
fileInputRef.current.value = "";
}
Logger.debug("Finished streaming message:", message);
Logger.debug("Token usage:", usage);
Logger.debug("Finish reason:", finishReason);
Expand All @@ -108,12 +129,70 @@ export const Chatbot: React.FC<Props> = (props) => {
setMessages(messages.filter((message) => message.id !== id));
};

const renderAttachment = (attachment: ChatAttachment) => {
if (attachment.contentType?.startsWith("image")) {
return (
<img
src={attachment.url}
alt={attachment.name || "Attachment"}
className="object-contain rounded-sm"
width={100}
height={100}
/>
);
}

return (
<a
href={attachment.url}
target="_blank"
rel="noopener noreferrer"
className="text-link hover:underline"
>
{attachment.name || "Attachment"}
</a>
);
};

const renderMessage = (message: Message) => {
return message.role === "assistant"
? renderHTML({ html: message.content })
: message.content;
const content =
message.role === "assistant"
? renderHTML({ html: message.content })
: message.content;

const attachments = message.experimental_attachments;

return (
<>
{content}
{attachments && attachments.length > 0 && (
<div className="mt-2">
{attachments.map((attachment, index) => (
<div key={index} className="flex items-baseline gap-2 ">
{renderAttachment(attachment)}
<a
className={buttonVariants({
variant: "text",
size: "icon",
})}
href={attachment.url}
download={attachment.name}
>
<DownloadIcon className="size-3" />
</a>
</div>
))}
</div>
)}
</>
);
};

const shouldShowAttachments =
(Array.isArray(props.allowAttachments) &&
props.allowAttachments.length > 0) ||
props.allowAttachments === true;

return (
<div className="flex flex-col h-full bg-[var(--slate-1)] rounded-lg shadow border border-[var(--slate-6)]">
<div className="flex justify-end p-1">
Expand Down Expand Up @@ -197,8 +276,12 @@ export const Chatbot: React.FC<Props> = (props) => {
)}

<form
onSubmit={handleSubmit}
className="flex w-full border-t border-[var(--slate-6)] px-2 py-1"
onSubmit={(evt) => {
handleSubmit(evt, {
experimental_attachments: files,
});
}}
className="flex w-full border-t border-[var(--slate-6)] px-2 py-1 items-center"
>
{props.showConfigurationControls && (
<ConfigPopup config={config} onChange={setConfig} />
Expand All @@ -223,9 +306,69 @@ export const Chatbot: React.FC<Props> = (props) => {
ref={inputRef}
value={input}
onChange={handleInputChange}
className="flex w-full outline-none bg-transparent ml-2 text-[var(--slate-12)]"
className="flex w-full outline-none bg-transparent ml-2 text-[var(--slate-12)] mr-2"
placeholder="Type your message..."
/>
{files && files.length === 1 && (
<span
title={files[0].name}
className="text-sm text-[var(--slate-11)] truncate flex-shrink-0 w-24"
>
{files[0].name}
</span>
)}
{files && files.length > 1 && (
<span
title={[...files].map((f) => f.name).join("\n")}
className="text-sm text-[var(--slate-11)] truncate flex-shrink-0"
>
{files.length} files
</span>
)}
{files && files.length > 0 && (
<Button
type="button"
variant="text"
size="sm"
onClick={() => {
setFiles(undefined);

if (fileInputRef.current) {
fileInputRef.current.value = "";
}
}}
>
<X className="size-3" />
</Button>
)}
{shouldShowAttachments && (
<>
<Button
type="button"
variant="text"
size="sm"
onClick={() => fileInputRef.current?.click()}
>
<PaperclipIcon className="h-4" />
</Button>
<input
type="file"
ref={fileInputRef}
className="hidden"
multiple={true}
accept={
Array.isArray(props.allowAttachments)
? props.allowAttachments.join(",")
: undefined
}
onChange={(event) => {
if (event.target.files) {
setFiles(event.target.files);
}
}}
/>
</>
)}
<Button
type="submit"
disabled={isLoading || !input}
Expand Down
11 changes: 10 additions & 1 deletion frontend/src/plugins/impl/chat/types.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
/* Copyright 2024 Marimo. All rights reserved. */
export type ChatRole = "system" | "user" | "assistant";

export interface ChatMessage {
role: "system" | "user" | "assistant";
role: ChatRole;
content: string;
attachments?: ChatAttachment[];
}

export interface ChatAttachment {
name?: string;
contentType?: string;
url: string;
}

export interface SendMessageRequest {
Expand Down
3 changes: 3 additions & 0 deletions marimo/_cli/file_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def can_read(self, name: str) -> bool:
return not is_url(name)

def read(self, name: str) -> Tuple[str, str]:
# Is directory
if os.path.isdir(name):
return "", os.path.basename(name)
with open(name, "r") as f:
content = f.read()
return content, os.path.basename(name)
Expand Down
Loading

0 comments on commit 585302f

Please sign in to comment.