
Adds additional_inputs to gr.ChatInterface #4985

Merged · 18 commits · Jul 24, 2023
5 changes: 5 additions & 0 deletions .changeset/witty-pets-rhyme.md
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:WIP: adding `additional_inputs` to `gr.ChatInterface`
1 change: 1 addition & 0 deletions demo/chatinterface_system_prompt/run.ipynb
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
18 changes: 18 additions & 0 deletions demo/chatinterface_system_prompt/run.py
@@ -0,0 +1,18 @@
import gradio as gr
import time

def echo(message, history, system_prompt, tokens):
    response = f"System prompt: {system_prompt}\n Message: {message}."
    for i in range(min(len(response), int(tokens))):
        time.sleep(0.05)
        yield response[: i+1]

demo = gr.ChatInterface(echo,
                        additional_inputs=[
                            gr.Textbox("You are helpful AI.", label="System Prompt"),
                            gr.Slider(10, 100)
                        ]
                        )

if __name__ == "__main__":
    demo.queue().launch()
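Per the new signature, `additional_inputs` also accepts a single component or a string shortcut rather than a list, since non-list values are wrapped into a list internally. A minimal sketch of that path (the shortcut and echo function here are illustrative, not part of this PR's demos):

```python
import gradio as gr

def echo(message, history, extra):
    # values of additional_inputs arrive as extra positional arguments
    # after (message, history)
    return f"{message} [extra={extra}]"

# "textbox" is a string shortcut resolved via get_component_instance
demo = gr.ChatInterface(echo, additional_inputs="textbox")

if __name__ == "__main__":
    demo.queue().launch()
```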
3 changes: 3 additions & 0 deletions gradio/blocks.py
@@ -110,6 +110,7 @@ def __init__(
        self.share_token = secrets.token_urlsafe(32)
        self._skip_init_processing = _skip_init_processing
        self.parent: BlockContext | None = None
        self.is_rendered: bool = False

        if render:
            self.render()
@@ -127,6 +128,7 @@ def render(self):
            Context.block.add(self)
        if Context.root_block is not None:
            Context.root_block.blocks[self._id] = self
            self.is_rendered = True
            if isinstance(self, components.IOComponent):
                Context.root_block.temp_file_sets.append(self.temp_files)
        return self
@@ -144,6 +146,7 @@ def unrender(self):
        if Context.root_block is not None:
            try:
                del Context.root_block.blocks[self._id]
                self.is_rendered = False
            except KeyError:
                pass
        return self
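A quick sketch of what the new `is_rendered` flag tracks (component names illustrative): it is set in `render()` only when the block is attached to a root Blocks, and cleared again by `unrender()`:

```python
import gradio as gr

with gr.Blocks() as demo:
    inside = gr.Textbox()  # attached to the surrounding Blocks by render()

outside = gr.Textbox(render=False)  # never attached to a root Blocks

print(inside.is_rendered)   # True
print(outside.is_rendered)  # False
```

ChatInterface uses exactly this flag below to decide which additional inputs still need to be rendered into its accordion.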
88 changes: 62 additions & 26 deletions gradio/chat_interface.py
@@ -6,22 +6,24 @@
from __future__ import annotations

import inspect
import warnings
from typing import Callable, Generator

from gradio_client import utils as client_utils
from gradio_client.documentation import document, set_documentation_group

from gradio.blocks import Blocks
from gradio.components import (
    Button,
    Chatbot,
    IOComponent,
    Markdown,
    State,
    Textbox,
    get_component_instance,
)
from gradio.events import Dependency, EventListenerMethod
from gradio.helpers import create_examples as Examples  # noqa: N812
from gradio.layouts import Column, Group, Row
from gradio.layouts import Accordion, Column, Group, Row
from gradio.themes import ThemeClass as Theme

set_documentation_group("chatinterface")
@@ -53,6 +55,8 @@ def __init__(
        *,
        chatbot: Chatbot | None = None,
        textbox: Textbox | None = None,
        additional_inputs: str | IOComponent | list[str | IOComponent] | None = None,
        additional_inputs_accordion_name: str = "Additional Inputs",
        examples: list[str] | None = None,
        cache_examples: bool | None = None,
        title: str | None = None,
@@ -65,12 +69,15 @@
        retry_btn: str | None | Button = "🔄 Retry",
        undo_btn: str | None | Button = "↩️ Undo",
        clear_btn: str | None | Button = "🗑️ Clear",
        autofocus: bool = True,
    ):
"""
Parameters:
fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format.
chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created.
textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created.
additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion.
additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided.
examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input.
cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False.
title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window.
@@ -83,6 +90,7 @@ def __init__(
            retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used.
            undo_btn: Text to display on the button that deletes the last message. If None, no button will be displayed. If a Button object, that button will be used.
            clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used.
            autofocus: If True, autofocuses the textbox when the page loads.
        """
        super().__init__(
            analytics_enabled=analytics_enabled,
@@ -91,12 +99,6 @@ def __init__(
            title=title or "Gradio",
            theme=theme,
        )
        if len(inspect.signature(fn).parameters) != 2:
            warnings.warn(
                "The function to ChatInterface should take two inputs (message, history) and return a single string response.",
                UserWarning,
            )

        self.fn = fn
        self.is_generator = inspect.isgeneratorfunction(self.fn)
        self.examples = examples
@@ -106,6 +108,16 @@ def __init__(
        self.cache_examples = cache_examples or False
        self.buttons: list[Button] = []

        if additional_inputs:
            if not isinstance(additional_inputs, list):
                additional_inputs = [additional_inputs]
            self.additional_inputs = [
                get_component_instance(i, render=False) for i in additional_inputs  # type: ignore
            ]
        else:
            self.additional_inputs = []
        self.additional_inputs_accordion_name = additional_inputs_accordion_name

        with self:
            if title:
                Markdown(
@@ -130,9 +142,10 @@ def __init__(
                self.textbox = Textbox(
                    container=False,
                    show_label=False,
                    label="Message",
[Review comment · Collaborator] Why are we adding a label here if show_label is false?

[Reply · Member, author] Because the label is used as the header of the examples table when examples are provided with additional inputs.
                    placeholder="Type a message...",
                    scale=7,
                    autofocus=True,
                    autofocus=autofocus,
                )
                if submit_btn:
                    if isinstance(submit_btn, Button):
@@ -199,12 +212,24 @@ def __init__(

                self.examples_handler = Examples(
                    examples=examples,
                    inputs=self.textbox,
                    inputs=[self.textbox] + self.additional_inputs,
                    outputs=self.chatbot,
                    fn=examples_fn,
                    cache_examples=self.cache_examples,
                )

            any_unrendered_inputs = any(
                not inp.is_rendered for inp in self.additional_inputs
            )
            if self.additional_inputs and any_unrendered_inputs:
                with Accordion(self.additional_inputs_accordion_name, open=False):
                    for input_component in self.additional_inputs:
                        if not input_component.is_rendered:
                            input_component.render()

            # The example caching must happen after the input components have rendered
            if cache_examples:
                client_utils.synchronize_async(self.examples_handler.cache)

            self.saved_input = State()
            self.chatbot_state = State([])
@@ -230,7 +255,7 @@ def _setup_events(self) -> None:
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state],
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                [self.chatbot, self.chatbot_state],
                api_name=False,
            )
@@ -255,7 +280,7 @@ def _setup_events(self) -> None:
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state],
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                [self.chatbot, self.chatbot_state],
                api_name=False,
            )
@@ -280,7 +305,7 @@ def _setup_events(self) -> None:
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state],
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                [self.chatbot, self.chatbot_state],
                api_name=False,
            )
@@ -358,7 +383,7 @@ def _setup_api(self) -> None:

        self.fake_api_btn.click(
            api_fn,
            [self.textbox, self.chatbot_state],
            [self.textbox, self.chatbot_state] + self.additional_inputs,
            [self.textbox, self.chatbot_state],
            api_name="chat",
        )
@@ -373,18 +398,26 @@ def _display_input(
        return history, history

    def _submit_fn(
        self, message: str, history_with_input: list[list[str | None]]
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
        **kwargs,
    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
        history = history_with_input[:-1]
        response = self.fn(message, history)
        response = self.fn(message, history, *args, **kwargs)
        history.append([message, response])
        return history, history

    def _stream_fn(
        self, message: str, history_with_input: list[list[str | None]]
        self,
        message: str,
        history_with_input: list[list[str | None]],
        *args,
        **kwargs,
    ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]:
        history = history_with_input[:-1]
        generator = self.fn(message, history)
        generator = self.fn(message, history, *args, **kwargs)
        try:
            first_response = next(generator)
            update = history + [[message, first_response]]
@@ -397,16 +430,16 @@ def _stream_fn(
            yield update, update

    def _api_submit_fn(
        self, message: str, history: list[list[str | None]]
        self, message: str, history: list[list[str | None]], *args, **kwargs
    ) -> tuple[str, list[list[str | None]]]:
        response = self.fn(message, history)
        history.append([message, response])
        return response, history

    def _api_stream_fn(
        self, message: str, history: list[list[str | None]]
        self, message: str, history: list[list[str | None]], *args, **kwargs
    ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]:
        generator = self.fn(message, history)
        generator = self.fn(message, history, *args, **kwargs)
        try:
            first_response = next(generator)
            yield first_response, history + [[message, first_response]]
@@ -415,13 +448,16 @@ def _api_stream_fn(
            for response in generator:
                yield response, history + [[message, response]]

    def _examples_fn(self, message: str) -> list[list[str | None]]:
        return [[message, self.fn(message, [])]]
    def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]:
        return [[message, self.fn(message, [], *args, **kwargs)]]

    def _examples_stream_fn(
        self, message: str
        self,
        message: str,
        *args,
        **kwargs,
    ) -> Generator[list[list[str | None]], None, None]:
        for response in self.fn(message, []):
        for response in self.fn(message, [], *args, **kwargs):
            yield [[message, response]]

    def _delete_prev_fn(
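Taken together, the `*args`/`**kwargs` threading above defines a simple contract: the chat function grows one positional parameter per additional input, appended after `(message, history)`. A plain-Python sketch of that contract (names hypothetical; this mirrors what `_submit_fn` does rather than reproducing it):

```python
def fn(message, history, system_prompt, temperature):
    # two additional_inputs => two extra positional parameters
    return f"({temperature}) {system_prompt}: {message}"

# Roughly what _submit_fn does with the current values of the
# additional input components:
history = [["hi", "hello"]]
response = fn("How are you?", history, "Be friendly.", 0.7)
history.append(["How are you?", response])
print(history[-1])
```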
52 changes: 33 additions & 19 deletions gradio/helpers.py
@@ -195,7 +195,7 @@ def __init__(
        self.non_none_examples = non_none_examples
        self.inputs = inputs
        self.inputs_with_examples = inputs_with_examples
        self.outputs = outputs
        self.outputs = outputs or []
        self.fn = fn
        self.cache_examples = cache_examples
        self._api_mode = _api_mode
@@ -250,23 +250,14 @@ async def create(self) -> None:
        component to hold the examples"""

        async def load_example(example_id):
            if self.cache_examples:
                processed_example = self.non_none_processed_examples[
                    example_id
                ] + await self.load_from_cache(example_id)
            else:
                processed_example = self.non_none_processed_examples[example_id]
            processed_example = self.non_none_processed_examples[example_id]
            return utils.resolve_singleton(processed_example)

        if Context.root_block:
            if self.cache_examples and self.outputs:
                targets = self.inputs_with_examples + self.outputs
            else:
                targets = self.inputs_with_examples
            load_input_event = self.dataset.click(
            self.load_input_event = self.dataset.click(
                load_example,
                inputs=[self.dataset],
                outputs=targets,  # type: ignore
                outputs=self.inputs_with_examples,  # type: ignore
                show_progress="hidden",
                postprocess=False,
                queue=False,
@@ -275,7 +266,7 @@ async def load_example(example_id):
            if self.run_on_click and not self.cache_examples:
                if self.fn is None:
                    raise ValueError("Cannot run_on_click if no function is provided")
                load_input_event.then(
                self.load_input_event.then(
                    self.fn,
                    inputs=self.inputs,  # type: ignore
                    outputs=self.outputs,  # type: ignore
@@ -301,25 +292,24 @@ async def cache(self) -> None:

        if inspect.isgeneratorfunction(self.fn):

            def get_final_item(args):  # type: ignore
            def get_final_item(*args):  # type: ignore
                x = None
                for x in self.fn(args):  # noqa: B007 # type: ignore
                for x in self.fn(*args):  # noqa: B007 # type: ignore
                    pass
                return x

            fn = get_final_item
        elif inspect.isasyncgenfunction(self.fn):

            async def get_final_item(args):
            async def get_final_item(*args):
                x = None
                async for x in self.fn(args):  # noqa: B007 # type: ignore
                async for x in self.fn(*args):  # noqa: B007 # type: ignore
                    pass
                return x

            fn = get_final_item
        else:
            fn = self.fn

        # create a fake dependency to process the examples and get the predictions
        dependency, fn_index = Context.root_block.set_event_trigger(
            event_name="fake_event",
@@ -352,6 +342,30 @@ async def get_final_item(args):
        # Remove the "fake_event" to prevent bugs in loading interfaces from spaces
        Context.root_block.dependencies.remove(dependency)
        Context.root_block.fns.pop(fn_index)

        # Remove the original load_input_event and replace it with one that
        # also populates the output. We do it this way to allow the cache()
        # method to be called independently of the create() method
        index = Context.root_block.dependencies.index(self.load_input_event)
        Context.root_block.dependencies.pop(index)
        Context.root_block.fns.pop(index)

        async def load_example(example_id):
            processed_example = self.non_none_processed_examples[
                example_id
            ] + await self.load_from_cache(example_id)
            return utils.resolve_singleton(processed_example)

        self.load_input_event = self.dataset.click(
            load_example,
            inputs=[self.dataset],
            outputs=self.inputs_with_examples + self.outputs,  # type: ignore
            show_progress="hidden",
            postprocess=False,
            queue=False,
            api_name=self.api_name,  # type: ignore
        )

        print("Caching complete\n")

    async def load_from_cache(self, example_id: int) -> list[Any]:
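For context on the `load_input_event` replacement above: once examples are cached, clicking an example is meant to populate both the example inputs and the cached output, since the rebuilt event targets `inputs_with_examples + self.outputs`. A standalone sketch of that behavior outside ChatInterface (names and values illustrative):

```python
import gradio as gr

def greet(name):
    return f"Hello, {name}!"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    greeting = gr.Textbox(label="Greeting")
    # With cache_examples=True, clicking "Ada" should fill both the Name
    # box and the cached Greeting via the replaced load_input_event
    gr.Examples(examples=["Ada", "Grace"], inputs=name, outputs=greeting,
                fn=greet, cache_examples=True)

if __name__ == "__main__":
    demo.launch()
```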