Allow image uploads to gr.load_chat (#10345)
* changes

* add changeset

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* changes

* Update gradio/external.py

* changes

* simplify tests

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
4 people authored Feb 3, 2025
1 parent 3750082 commit 39f0c23
Showing 4 changed files with 216 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .changeset/brown-insects-say.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Allow image uploads to gr.load_chat
177 changes: 172 additions & 5 deletions gradio/external.py
@@ -17,10 +17,12 @@
from gradio_client import Client
from gradio_client.client import Endpoint
from gradio_client.documentation import document
from gradio_client.utils import encode_url_or_file_to_base64
from packaging import version

import gradio
from gradio import components, external_utils, utils
from gradio.components.multimodal_textbox import MultimodalValue
from gradio.context import Context
from gradio.exceptions import (
GradioVersionIncompatibleError,
@@ -31,6 +33,7 @@
if TYPE_CHECKING:
from gradio.blocks import Blocks
from gradio.chat_interface import ChatInterface
from gradio.components.chatbot import MessageDict
from gradio.interface import Interface


@@ -586,12 +589,146 @@ def fn(*data):
return interface


TEXT_FILE_EXTENSIONS = (
".doc",
".docx",
".rtf",
".epub",
".odt",
".odp",
".pptx",
".txt",
".md",
".py",
".ipynb",
".js",
".jsx",
".html",
".css",
".java",
".cs",
".php",
".c",
".cc",
".cpp",
".cxx",
".cts",
".h",
".hh",
".hpp",
".rs",
".R",
".Rmd",
".swift",
".go",
".rb",
".kt",
".kts",
".ts",
".tsx",
".m",
".mm",
".mts",
".scala",
".dart",
".lua",
".pl",
".pm",
".t",
".sh",
".bash",
".zsh",
".bat",
".coffee",
".csv",
".log",
".ini",
".cfg",
".config",
".json",
".proto",
".yaml",
".yml",
".toml",
".sql",
)
IMAGE_FILE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".webp")


def format_conversation(
history: list[MessageDict], new_message: str | MultimodalValue
) -> list[dict]:
conversation = []
for message in history:
if isinstance(message["content"], str):
conversation.append(
{"role": message["role"], "content": message["content"]}
)
elif isinstance(message["content"], tuple):
image_message = {
"role": message["role"],
"content": [
{
"type": "image_url",
"image_url": {
"url": encode_url_or_file_to_base64(message["content"][0])
},
}
],
}
conversation.append(image_message)
else:
raise ValueError(
f"Invalid message format: {message['content']}. Messages must be either strings or tuples."
)
if isinstance(new_message, str):
text = new_message
files = []
else:
text = new_message.get("text", None)
files = new_message.get("files", [])
image_files, text_encoded = [], []
for file in files:
if file.lower().endswith(TEXT_FILE_EXTENSIONS):
text_encoded.append(file)
else:
image_files.append(file)

for image in image_files:
conversation.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": encode_url_or_file_to_base64(image)},
}
],
}
)
if text or text_encoded:
text = text or ""
text += "\n".join(
[
f"\n## {Path(file).name}\n{Path(file).read_text()}"
for file in text_encoded
]
)
conversation.append(
{"role": "user", "content": [{"type": "text", "text": text}]}
)
return conversation


@document()
def load_chat(
base_url: str,
model: str,
token: str | None = None,
*,
file_types: Literal["text_encoded", "image"]
| list[Literal["text_encoded", "image"]]
| None = "text_encoded",
system_message: str | None = None,
streaming: bool = True,
**kwargs,
@@ -602,9 +739,19 @@ def load_chat(
base_url: The base URL of the endpoint, e.g. "http://localhost:11434/v1/"
model: The name of the model you are loading, e.g. "llama3.2"
token: The API token or a placeholder string if you are using a local model, e.g. "ollama"
file_types: The file types allowed to be uploaded by the user. "text_encoded" allows uploading any text-encoded file (which is simply appended to the prompt), and "image" adds image upload support. Set to None to disable file uploads.
system_message: The system message to use for the conversation, if any.
streaming: Whether the response should be streamed.
kwargs: Additional keyword arguments to pass into ChatInterface for customization.
Example:
import gradio as gr
gr.load_chat(
"http://localhost:11434/v1/",
model="qwen2.5",
token="***",
file_types=["text_encoded", "image"],
system_message="You are a silly assistant.",
).launch()
"""
try:
from openai import OpenAI
@@ -618,29 +765,32 @@
start_message = (
[{"role": "system", "content": system_message}] if system_message else []
)
file_types = utils.none_or_singleton_to_list(file_types)

def open_api(message: str, history: list | None) -> str | None:
def open_api(message: str | MultimodalValue, history: list | None) -> str | None:
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
conversation = format_conversation(history, message)
return (
client.chat.completions.create(
model=model,
messages=history + [{"role": "user", "content": message}],
messages=conversation, # type: ignore
)
.choices[0]
.message.content
)

def open_api_stream(
message: str, history: list | None
message: str | MultimodalValue, history: list | None
) -> Generator[str, None, None]:
history = history or start_message
if len(history) > 0 and isinstance(history[0], (list, tuple)):
history = ChatInterface._tuples_to_messages(history)
conversation = format_conversation(history, message)
stream = client.chat.completions.create(
model=model,
messages=history + [{"role": "user", "content": message}],
messages=conversation, # type: ignore
stream=True,
)
response = ""
@@ -649,6 +799,23 @@ def open_api_stream(
response += chunk.choices[0].delta.content
yield response

supported_extensions = []
for file_type in file_types:
if file_type == "text_encoded":
supported_extensions += TEXT_FILE_EXTENSIONS
elif file_type == "image":
supported_extensions += IMAGE_FILE_EXTENSIONS
else:
raise ValueError(
f"Invalid file type: {file_type}. Must be 'text_encoded' or 'image'."
)

return ChatInterface(
open_api_stream if streaming else open_api, type="messages", **kwargs
open_api_stream if streaming else open_api,
type="messages",
multimodal=bool(file_types),
textbox=gradio.MultimodalTextbox(file_types=supported_extensions)
if file_types
else None,
**kwargs,
)
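
For illustration, here is a minimal sketch (not part of the diff) of the OpenAI-style message list that `format_conversation` builds for a multimodal turn. The file names and contents are hypothetical, and the helper is assumed to remain importable from `gradio.external`:

```python
from pathlib import Path

from gradio.external import format_conversation

# Hypothetical files for illustration; format_conversation reads them from disk.
Path("notes.md").write_text("- buy milk\n- fix bug")
Path("chart.png").write_bytes(b"\x89PNG\r\n\x1a\n")  # stand-in bytes, not a real image

history = [{"role": "assistant", "content": "Hi! Send me files."}]
conversation = format_conversation(
    history, {"text": "Summarize these", "files": ["notes.md", "chart.png"]}
)
# conversation is now roughly:
# [
#     {"role": "assistant", "content": "Hi! Send me files."},
#     {"role": "user", "content": [{"type": "image_url",
#         "image_url": {"url": "data:image/png;base64,..."}}]},
#     {"role": "user", "content": [{"type": "text",
#         "text": "Summarize these\n## notes.md\n- buy milk\n- fix bug"}]},
# ]
```

Each image becomes its own `image_url` message, while text-encoded files are folded into the text part under a `## <filename>` header.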
6 changes: 3 additions & 3 deletions guides/05_chatbots/01_creating-a-chatbot-fast.md
@@ -16,15 +16,15 @@ $ pip install --upgrade gradio

## Note for OpenAI-API compatible endpoints

If you have a chat server serving an OpenAI-API compatible endpoint (e.g. Ollama), you can spin up a ChatInterface in a single line of Python. First, also run `pip install openai`. Then, with your own URL, model, and optional token:
If you have a chat server serving an OpenAI-API compatible endpoint (such as Ollama), you can spin up a ChatInterface in a single line of Python. First, also run `pip install openai`. Then, with your own URL, model, and optional token:

```python
import gradio as gr

gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="ollama").launch()
gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token="***").launch()
```

If you have your own model, keep reading to see how to create an application around any chat model in Python!
Read about `gr.load_chat` in [the docs](https://www.gradio.app/docs/gradio/load_chat). If you have your own model, keep reading to see how to create an application around any chat model in Python!
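
As of this change, `gr.load_chat` also supports file uploads via the `file_types` argument. A minimal sketch, assuming your server hosts a vision-capable model (the model name below is a placeholder):

```python
import gradio as gr

gr.load_chat(
    "http://localhost:11434/v1/",
    model="llava",  # placeholder; use any vision-capable model your server hosts
    token="***",
    file_types=["text_encoded", "image"],  # allow text-encoded files and images
).launch()
```

Text-encoded files are appended to the prompt, while images are sent to the model as base64-encoded `image_url` content parts.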

## Defining a chat function

36 changes: 36 additions & 0 deletions test/test_external.py
@@ -523,3 +523,39 @@ def mock_src(name: str, token: str | None, **kwargs) -> gr.Blocks:
)

assert isinstance(result, gr.Blocks)


@patch("openai.OpenAI")
def test_load_chat_basic(mock_openai):
mock_client = MagicMock()
mock_client.chat.completions.create.return_value.choices[
0
].message.content = "Hello human!"
mock_openai.return_value = mock_client

chat = gr.load_chat(
"http://fake-api.com/v1",
model="test-model",
token="fake-token",
streaming=False,
)
response = chat.fn("Hi AI!", None)
assert response == "Hello human!"


@patch("openai.OpenAI")
def test_load_chat_with_streaming(mock_openai):
mock_client = MagicMock()
mock_stream = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hello"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content=" World"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content="!"))]),
]
mock_client.chat.completions.create.return_value = mock_stream
mock_openai.return_value = mock_client
chat = gr.load_chat(
"http://fake-api.com/v1", model="test-model", token="fake-token", streaming=True
)
response_stream = chat.fn("Hi!", None)
responses = list(response_stream)
assert responses == ["Hello", "Hello World", "Hello World!"]
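
A sketch of how the new multimodal path might be exercised in the same style (hypothetical, not part of the diff; it assumes pytest's `tmp_path` fixture and the imports already used above):

```python
@patch("openai.OpenAI")
def test_load_chat_with_image_upload(mock_openai, tmp_path):
    # Hypothetical sketch: send a MultimodalValue-style dict through chat.fn.
    mock_client = MagicMock()
    mock_client.chat.completions.create.return_value.choices[
        0
    ].message.content = "I see a picture!"
    mock_openai.return_value = mock_client

    image = tmp_path / "cat.png"
    image.write_bytes(b"\x89PNG\r\n\x1a\n")  # stand-in bytes; the .png extension routes it as an image

    chat = gr.load_chat(
        "http://fake-api.com/v1",
        model="test-model",
        token="fake-token",
        file_types="image",
        streaming=False,
    )
    response = chat.fn({"text": "What is this?", "files": [str(image)]}, None)
    assert response == "I see a picture!"
```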
