From 4528b4bd4a0cccad23ee6b39310520c643429ef2 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Thu, 20 Jul 2023 14:12:53 +0300 Subject: [PATCH 01/13] adding additional inputs --- gradio/chat_interface.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 7c6bc63455e91..75bb785111a30 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -15,12 +15,14 @@ from gradio.components import ( Button, Chatbot, + IOComponent, Markdown, State, Textbox, + get_component_instance, ) from gradio.helpers import create_examples as Examples # noqa: N812 -from gradio.layouts import Group, Row +from gradio.layouts import Accordion, Group, Row from gradio.themes import ThemeClass as Theme set_documentation_group("chatinterface") @@ -52,6 +54,7 @@ def __init__( *, chatbot: Chatbot | None = None, textbox: Textbox | None = None, + additional_inputs: str | IOComponent | list[str | IOComponent] | None, examples: list[str] | None = None, cache_examples: bool | None = None, title: str | None = None, @@ -102,6 +105,15 @@ def __init__( self.cache_examples = cache_examples or False self.buttons: list[Button] = [] + if additional_inputs: + if not isinstance(additional_inputs, list): + additional_inputs = [additional_inputs] + self.additional_inputs = [ + get_component_instance(i, render=False) for i in additional_inputs # type: ignore + ] + else: + self.additional_inputs = None + with self: if title: Markdown( @@ -176,6 +188,11 @@ def __init__( cache_examples=self.cache_examples, ) + if self.additional_inputs: + with Accordion("Additional Inputs"): + for input_component in self.additional_inputs: + input_component.render() + self.saved_input = State() self._setup_events() From 8c38d935562a5c1d5caf6be2dc341b9cfa3e8c83 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 14:07:53 +0300 Subject: [PATCH 02/13] add param --- gradio/chat_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index df6f33ca08407..5691c5710f1e8 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -24,7 +24,6 @@ from gradio.events import Dependency, EventListenerMethod from gradio.helpers import create_examples as Examples # noqa: N812 from gradio.layouts import Accordion, Column, Group, Row -from gradio.layouts import Column, Group, Row from gradio.themes import ThemeClass as Theme set_documentation_group("chatinterface") @@ -57,6 +56,7 @@ def __init__( chatbot: Chatbot | None = None, textbox: Textbox | None = None, additional_inputs: str | IOComponent | list[str | IOComponent] | None, + additional_inputs_accordion_name: str = "Additional Inputs", examples: list[str] | None = None, cache_examples: bool | None = None, title: str | None = None, From ed448919ff6b8d324e106fb233a474fb337a0ef4 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 14:09:30 +0300 Subject: [PATCH 03/13] guide --- guides/04_chatbots/01_creating-a-chatbot-fast.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guides/04_chatbots/01_creating-a-chatbot-fast.md b/guides/04_chatbots/01_creating-a-chatbot-fast.md index 0b307537f9c26..5afe07eefd356 100644 --- a/guides/04_chatbots/01_creating-a-chatbot-fast.md +++ b/guides/04_chatbots/01_creating-a-chatbot-fast.md @@ -87,7 +87,7 @@ def slow_echo(message, history): gr.ChatInterface(slow_echo).queue().launch() ``` -Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions. +Notice that we've [enabled queuing](/guides/key-features#queuing), which is required to use generator functions. While the response is streaming, the "Submit" button turns into a "Stop" button that can be used to stop the generator function. You can customize the appearance and behavior of the "Stop" button using the `stop_btn` parameter. ## Customizing your chatbot From 835c32385a580ec3ebcba28eba2a3b1574048cb7 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 19:22:46 +0300 Subject: [PATCH 04/13] add is_rendered --- gradio/blocks.py | 3 +++ gradio/chat_interface.py | 8 ++++++-- test/test_blocks.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/gradio/blocks.py b/gradio/blocks.py index f441d39d9b981..8386ecf21fd5c 100644 --- a/gradio/blocks.py +++ b/gradio/blocks.py @@ -110,6 +110,7 @@ def __init__( self.share_token = secrets.token_urlsafe(32) self._skip_init_processing = _skip_init_processing self.parent: BlockContext | None = None + self.is_rendered: bool = False if render: self.render() @@ -127,6 +128,7 @@ def render(self): Context.block.add(self) if Context.root_block is not None: Context.root_block.blocks[self._id] = self + self.is_rendered = True if isinstance(self, components.IOComponent): Context.root_block.temp_file_sets.append(self.temp_files) return self @@ -144,6 +146,7 @@ def unrender(self): if Context.root_block is not None: try: del Context.root_block.blocks[self._id] + self.is_rendered = False except KeyError: pass return self diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 5691c5710f1e8..d327e233d574b 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -75,6 +75,8 @@ def __init__( fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format. chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created. textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created. + additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed underthe chatbot, in an accordion. + additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided. examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input. cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False. title: a title for the interface; if provided, appears above chatbot in large font. Also used as the tab title when opened in a browser window. @@ -118,6 +120,7 @@ def __init__( ] else: self.additional_inputs = None + self.additional_inputs_accordion_name = additional_inputs_accordion_name with self: if title: @@ -219,9 +222,10 @@ def __init__( ) if self.additional_inputs: - with Accordion("Additional Inputs"): + with Accordion(self.additional_inputs_accordion_name): for input_component in self.additional_inputs: - input_component.render() + if not input_component.is_rendered(): + input_component.render() self.saved_input = State() self.chatbot_state = State([]) diff --git a/test/test_blocks.py b/test/test_blocks.py index a61423137edd9..4126a1fb67c92 100644 --- a/test/test_blocks.py +++ b/test/test_blocks.py @@ -1221,6 +1221,37 @@ def test_no_error(self): io3 = io2.render() assert io2 == io3 + def test_is_rendered(self): + t = gr.Textbox() + with gr.Blocks(): + pass + assert not t.is_rendered + + t = gr.Textbox() + with gr.Blocks(): + t.render() + assert t.is_rendered + + t = gr.Textbox() + with gr.Blocks(): + t.render() + t.unrender() + assert not t.is_rendered + + with gr.Blocks(): + t = gr.Textbox() + assert t.is_rendered + + with gr.Blocks(): + t = gr.Textbox() + with gr.Blocks(): + pass + assert t.is_rendered + + t = gr.Textbox() + gr.Interface(lambda x: x, "textbox", t) + assert t.is_rendered + def test_no_error_if_state_rendered_multiple_times(self): state = gr.State("") gr.TabbedInterface( From 98772c617916e15f180a9580484b3c74be376035 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 21:15:05 +0300 Subject: [PATCH 05/13] add demo --- demo/chatinterface_system_prompt/run.ipynb | 1 + demo/chatinterface_system_prompt/run.py | 18 +++++++ gradio/chat_interface.py | 60 ++++++++++++---------- test/test_chat_interface.py | 4 -- 4 files changed, 52 insertions(+), 31 deletions(-) create mode 100644 demo/chatinterface_system_prompt/run.ipynb create mode 100644 demo/chatinterface_system_prompt/run.py diff --git a/demo/chatinterface_system_prompt/run.ipynb b/demo/chatinterface_system_prompt/run.ipynb new file mode 100644 index 0000000000000..c92e7ea5d9082 --- /dev/null +++ b/demo/chatinterface_system_prompt/run.ipynb @@ -0,0 +1 @@ +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatinterface_system_prompt"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import time\n", "\n", "def echo(message, history, system_prompt, tokens):\n", " response = f\"System prompt: {system_prompt}\\n Message: {message}.\"\n", " for i in range(min(len(response), int(tokens))):\n", " time.sleep(0.05)\n", " yield response[: i+1]\n", "\n", "demo = gr.ChatInterface(echo, \n", " additional_inputs=[\n", " gr.Textbox(\"You are helpful AI.\", label=\"System Prompt\"), \n", " gr.Slider(10, 100)\n", " ]\n", " )\n", "\n", "if __name__ == \"__main__\":\n", " demo.queue().launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatinterface_system_prompt/run.py b/demo/chatinterface_system_prompt/run.py new file mode 100644 index 0000000000000..e8b1422c4f788 --- /dev/null +++ b/demo/chatinterface_system_prompt/run.py @@ -0,0 +1,18 @@ +import gradio as gr +import time + +def echo(message, history, system_prompt, tokens): + response = f"System prompt: {system_prompt}\n Message: {message}." + for i in range(min(len(response), int(tokens))): + time.sleep(0.05) + yield response[: i+1] + +demo = gr.ChatInterface(echo, + additional_inputs=[ + gr.Textbox("You are helpful AI.", label="System Prompt"), + gr.Slider(10, 100) + ] + ) + +if __name__ == "__main__": + demo.queue().launch() \ No newline at end of file diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index d327e233d574b..d753263d534ff 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -6,7 +6,6 @@ from __future__ import annotations import inspect -import warnings from typing import Callable, Generator from gradio_client.documentation import document, set_documentation_group @@ -69,6 +68,7 @@ def __init__( retry_btn: str | None | Button = "🔄 Retry", undo_btn: str | None | Button = "↩ī¸ Undo", clear_btn: str | None | Button = "🗑ī¸ Clear", + autofocus: bool = True, ): """ Parameters: @@ -89,6 +89,7 @@ def __init__( retry_btn: Text to display on the retry button. If None, no button will be displayed. If a Button object, that button will be used. undo_btn: Text to display on the delete last button. If None, no button will be displayed. If a Button object, that button will be used. clear_btn: Text to display on the clear button. If None, no button will be displayed. If a Button object, that button will be used. + autofocus: If True, autofocuses to the textbox when the page loads. """ super().__init__( analytics_enabled=analytics_enabled, @@ -97,12 +98,6 @@ def __init__( title=title or "Gradio", theme=theme, ) - if len(inspect.signature(fn).parameters) != 2: - warnings.warn( - "The function to ChatInterface should take two inputs (message, history) and return a single string response.", - UserWarning, - ) - self.fn = fn self.is_generator = inspect.isgeneratorfunction(self.fn) self.examples = examples @@ -119,7 +114,7 @@ def __init__( get_component_instance(i, render=False) for i in additional_inputs # type: ignore ] else: - self.additional_inputs = None + self.additional_inputs = [] self.additional_inputs_accordion_name = additional_inputs_accordion_name with self: @@ -148,7 +143,7 @@ def __init__( show_label=False, placeholder="Type a message...", scale=7, - autofocus=True, + autofocus=autofocus, ) if submit_btn: if isinstance(submit_btn, Button): @@ -215,16 +210,16 @@ def __init__( self.examples_handler = Examples( examples=examples, - inputs=self.textbox, + inputs=self.textbox + self.additional_inputs, outputs=self.chatbot, fn=examples_fn, cache_examples=self.cache_examples, ) if self.additional_inputs: - with Accordion(self.additional_inputs_accordion_name): + with Accordion(self.additional_inputs_accordion_name, open=False): for input_component in self.additional_inputs: - if not input_component.is_rendered(): + if not input_component.is_rendered: input_component.render() self.saved_input = State() @@ -252,7 +247,7 @@ def _setup_events(self) -> None: ) .then( submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -277,7 +272,7 @@ def _setup_events(self) -> None: ) .then( submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -302,7 +297,7 @@ def _setup_events(self) -> None: ) .then( submit_fn, - [self.saved_input, self.chatbot_state], + [self.saved_input, self.chatbot_state] + self.additional_inputs, [self.chatbot, self.chatbot_state], api_name=False, ) @@ -380,7 +375,7 @@ def _setup_api(self) -> None: self.fake_api_btn.click( api_fn, - [self.textbox, self.chatbot_state], + [self.textbox, self.chatbot_state] + self.additional_inputs, [self.textbox, self.chatbot_state], api_name="chat", ) @@ -395,18 +390,26 @@ def _display_input( return history, history def _submit_fn( - self, message: str, history_with_input: list[list[str | None]] + self, + message: str, + history_with_input: list[list[str | None]], + *args, + **kwargs, ) -> tuple[list[list[str | None]], list[list[str | None]]]: history = history_with_input[:-1] - response = self.fn(message, history) + response = self.fn(message, history, *args, **kwargs) history.append([message, response]) return history, history def _stream_fn( - self, message: str, history_with_input: list[list[str | None]] + self, + message: str, + history_with_input: list[list[str | None]], + *args, + **kwargs, ) -> Generator[tuple[list[list[str | None]], list[list[str | None]]], None, None]: history = history_with_input[:-1] - generator = self.fn(message, history) + generator = self.fn(message, history, *args, **kwargs) try: first_response = next(generator) update = history + [[message, first_response]] @@ -419,16 +422,16 @@ def _stream_fn( yield update, update def _api_submit_fn( - self, message: str, history: list[list[str | None]] + self, message: str, history: list[list[str | None]], *args, **kwargs ) -> tuple[str, list[list[str | None]]]: response = self.fn(message, history) history.append([message, response]) return response, history def _api_stream_fn( - self, message: str, history: list[list[str | None]] + self, message: str, history: list[list[str | None]], *args, **kwargs ) -> Generator[tuple[str | None, list[list[str | None]]], None, None]: - generator = self.fn(message, history) + generator = self.fn(message, history, *args, **kwargs) try: first_response = next(generator) yield first_response, history + [[message, first_response]] @@ -437,13 +440,16 @@ def _api_stream_fn( for response in generator: yield response, history + [[message, response]] - def _examples_fn(self, message: str) -> list[list[str | None]]: - return [[message, self.fn(message, [])]] + def _examples_fn(self, message: str, *args, **kwargs) -> list[list[str | None]]: + return [[message, self.fn(message, [], *args, **kwargs)]] def _examples_stream_fn( - self, message: str + self, + message: str, + *args, + **kwargs, ) -> Generator[list[list[str | None]], None, None]: - for response in self.fn(message, []): + for response in self.fn(message, [], *args, **kwargs): yield [[message, response]] def _delete_prev_fn( diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index b398325da0da5..50debe248389c 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -27,10 +27,6 @@ def test_no_fn(self): with pytest.raises(TypeError): gr.ChatInterface() - def test_invalid_fn_inputs(self): - with pytest.warns(UserWarning): - gr.ChatInterface(invalid_fn) - def test_configuring_buttons(self): chatbot = gr.ChatInterface(double, submit_btn=None, retry_btn=None) assert chatbot.submit_btn is None From e1c290047442372579e59314322f96be5a1e3227 Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 21:57:27 +0300 Subject: [PATCH 06/13] fixing examples --- gradio/chat_interface.py | 12 +++++++--- gradio/helpers.py | 51 +++++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index d753263d534ff..7a0dbcca9ad41 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -8,6 +8,7 @@ import inspect from typing import Callable, Generator +from gradio_client import utils as client_utils from gradio_client.documentation import document, set_documentation_group from gradio.blocks import Blocks @@ -54,7 +55,7 @@ def __init__( *, chatbot: Chatbot | None = None, textbox: Textbox | None = None, - additional_inputs: str | IOComponent | list[str | IOComponent] | None, + additional_inputs: str | IOComponent | list[str | IOComponent] | None = None, additional_inputs_accordion_name: str = "Additional Inputs", examples: list[str] | None = None, cache_examples: bool | None = None, @@ -141,6 +142,7 @@ def __init__( self.textbox = Textbox( container=False, show_label=False, + label="Message", placeholder="Type a message...", scale=7, autofocus=autofocus, @@ -210,10 +212,10 @@ def __init__( self.examples_handler = Examples( examples=examples, - inputs=self.textbox + self.additional_inputs, + inputs=[self.textbox] + self.additional_inputs, outputs=self.chatbot, fn=examples_fn, - cache_examples=self.cache_examples, + # cache_examples=cache_examples ) if self.additional_inputs: @@ -222,6 +224,10 @@ def __init__( if not input_component.is_rendered: input_component.render() + # The example caching must happen after the input components have rendered + if cache_examples: + client_utils.synchronize_async(self.examples_handler.cache) + self.saved_input = State() self.chatbot_state = State([]) diff --git a/gradio/helpers.py b/gradio/helpers.py index 660c89de70917..9032e532b9419 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -195,7 +195,7 @@ def __init__( self.non_none_examples = non_none_examples self.inputs = inputs self.inputs_with_examples = inputs_with_examples - self.outputs = outputs + self.outputs = outputs or [] self.fn = fn self.cache_examples = cache_examples self._api_mode = _api_mode @@ -250,23 +250,14 @@ async def create(self) -> None: component to hold the examples""" async def load_example(example_id): - if self.cache_examples: - processed_example = self.non_none_processed_examples[ - example_id - ] + await self.load_from_cache(example_id) - else: - processed_example = self.non_none_processed_examples[example_id] + processed_example = self.non_none_processed_examples[example_id] return utils.resolve_singleton(processed_example) if Context.root_block: - if self.cache_examples and self.outputs: - targets = self.inputs_with_examples + self.outputs - else: - targets = self.inputs_with_examples - load_input_event = self.dataset.click( + self.load_input_event = self.dataset.click( load_example, inputs=[self.dataset], - outputs=targets, # type: ignore + outputs=self.inputs_with_examples, # type: ignore show_progress="hidden", postprocess=False, queue=False, @@ -275,7 +266,7 @@ async def load_example(example_id): if self.run_on_click and not self.cache_examples: if self.fn is None: raise ValueError("Cannot run_on_click if no function is provided") - load_input_event.then( + self.load_input_event.then( self.fn, inputs=self.inputs, # type: ignore outputs=self.outputs, # type: ignore @@ -301,25 +292,24 @@ async def cache(self) -> None: if inspect.isgeneratorfunction(self.fn): - def get_final_item(args): # type: ignore + def get_final_item(*args): # type: ignore x = None - for x in self.fn(args): # noqa: B007 # type: ignore + for x in self.fn(*args): # noqa: B007 # type: ignore pass return x fn = get_final_item elif inspect.isasyncgenfunction(self.fn): - async def get_final_item(args): + async def get_final_item(*args): x = None - async for x in self.fn(args): # noqa: B007 # type: ignore + async for x in self.fn(*args): # noqa: B007 # type: ignore pass return x fn = get_final_item else: fn = self.fn - # create a fake dependency to process the examples and get the predictions dependency, fn_index = Context.root_block.set_event_trigger( event_name="fake_event", @@ -352,6 +342,29 @@ async def get_final_item(args): # Remove the "fake_event" to prevent bugs in loading interfaces from spaces Context.root_block.dependencies.remove(dependency) Context.root_block.fns.pop(fn_index) + + # Remove the original load_input_event and replace it with one that + # also populates the input. We do it this way to to allow the cache() + # method to be called independently of the create() method + index = Context.root_block.dependencies.index(self.load_input_event) + Context.root_block.dependencies.pop(index) + Context.root_block.fns.pop(index) + async def load_example(example_id): + processed_example = self.non_none_processed_examples[ + example_id + ] + await self.load_from_cache(example_id) + return utils.resolve_singleton(processed_example) + + self.load_input_event = self.dataset.click( + load_example, + inputs=[self.dataset], + outputs=self.inputs_with_examples + self.outputs, # type: ignore + show_progress="hidden", + postprocess=False, + queue=False, + api_name=self.api_name, # type: ignore + ) + print("Caching complete\n") async def load_from_cache(self, example_id: int) -> list[Any]: From 191011ec717d9caa1d848f3fce49fe99d0377b9c Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 22:11:55 +0300 Subject: [PATCH 07/13] add test --- gradio/helpers.py | 7 ++++--- test/test_chat_interface.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/gradio/helpers.py b/gradio/helpers.py index 9032e532b9419..188bd6059b0c4 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -343,18 +343,19 @@ async def get_final_item(*args): Context.root_block.dependencies.remove(dependency) Context.root_block.fns.pop(fn_index) - # Remove the original load_input_event and replace it with one that + # Remove the original load_input_event and replace it with one that # also populates the input. We do it this way to to allow the cache() - # method to be called independently of the create() method + # method to be called independently of the create() method index = Context.root_block.dependencies.index(self.load_input_event) Context.root_block.dependencies.pop(index) Context.root_block.fns.pop(index) + async def load_example(example_id): processed_example = self.non_none_processed_examples[ example_id ] + await self.load_from_cache(example_id) return utils.resolve_singleton(processed_example) - + self.load_input_event = self.dataset.click( load_example, inputs=[self.dataset], diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index 50debe248389c..710c6243d56bf 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -22,6 +22,12 @@ def count(message, history): return str(len(history)) +def echo_system_prompt_plus_message(message, history, system_prompt, tokens): + response = f"{system_prompt} {message}" + for i in range(min(len(response), int(tokens))): + yield response[: i + 1] + + class TestInit: def test_no_fn(self): with pytest.raises(TypeError): @@ -79,6 +85,19 @@ async def test_example_caching_with_streaming(self): assert prediction_hello[0][0] == ["hello", "hello"] assert prediction_hi[0][0] == ["hi", "hi"] + @pytest.mark.asyncio + async def test_example_caching_with_additional_inputs(self): + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=["textbox", "slider"], + examples=[["hello", "robot", 100], ["hi", "robot", 2]], + cache_examples=True, + ) + prediction_hello = await chatbot.examples_handler.load_from_cache(0) + prediction_hi = await chatbot.examples_handler.load_from_cache(1) + assert prediction_hello[0][0] == ["hello", "robot hello"] + assert prediction_hi[0][0] == ["hi", "ro"] + class TestAPI: def test_get_api_info(self): @@ -100,3 +119,21 @@ def test_non_streaming_api(self, connect): with connect(chatbot) as client: result = client.predict("hello") assert result == "hello hello" + + def test_streaming_api_with_additional_inputs(self, connect): + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=["textbox", "slider"], + ).queue() + with connect(chatbot) as client: + job = client.submit("hello", "robot", 7) + wait([job]) + assert job.outputs() == [ + "r", + "ro", + "rob", + "robo", + "robot", + "robot ", + "robot h", + ] From 0cb940a04e949bcb74a19b93c358f34813bd45bb Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Fri, 21 Jul 2023 22:30:38 +0300 Subject: [PATCH 08/13] guide --- gradio/chat_interface.py | 5 ++- .../04_chatbots/01_creating-a-chatbot-fast.md | 37 ++++++++++++++++++- 2 files changed, 39 insertions(+), 3 deletions(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 7a0dbcca9ad41..ca23d7d9f95c6 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -218,7 +218,10 @@ def __init__( # cache_examples=cache_examples ) - if self.additional_inputs: + any_unrendered_inputs = any( + not inp.is_rendered for inp in self.additional_inputs + ) + if self.additional_inputs and any_unrendered_inputs: with Accordion(self.additional_inputs_accordion_name, open=False): for input_component in self.additional_inputs: if not input_component.is_rendered: diff --git a/guides/04_chatbots/01_creating-a-chatbot-fast.md b/guides/04_chatbots/01_creating-a-chatbot-fast.md index 5afe07eefd356..353dcd6d5ca01 100644 --- a/guides/04_chatbots/01_creating-a-chatbot-fast.md +++ b/guides/04_chatbots/01_creating-a-chatbot-fast.md @@ -125,11 +125,44 @@ gr.ChatInterface( ).launch() ``` +## Additional Inputs + +You may want to add additional parameters to your chatbot and expose them to your users through the Chatbot UI. For example, suppose you want to add a textbox for a system prompt, or a slider that sets the number of tokens in the chatbot's response. The `ChatInterface` class supports an `additional_inputs` parameter which can be used to add additional input components. + +The `additional_inputs` parameters accepts a component or a list of components. You can pass the component instances directly, or use their string shortcuts (e.g. `"textbox"` instead of `gr.Textbox()`). If you pass in component instances, and they have *not* already been rendered, then the components will appear underneath the chatbot (and any examples) within a `gr.Accordion()`. You can set the label of this accordion using the `additional_inputs_accordion_name` parameter. + +Here's a complete example: + +$code_chatinterface_system_prompt + +If the components you pass into the `additional_inputs` have already been rendered in a parent `gr.Blocks()`, then they will *not* be re-rendered in the accordion. This provides flexibility in deciding where to lay out the input components. In the example below, we position the `gr.Textbox()` on top of the Chatbot UI, while keeping the slider underneath. + +```python +import gradio as gr +import time + +def echo(message, history, system_prompt, tokens): + response = f"System prompt: {system_prompt}\n Message: {message}." + for i in range(min(len(response), int(tokens))): + time.sleep(0.05) + yield response[: i+1] + +with gr.Blocks() as demo: + system_prompt = gr.Textbox("You are helpful AI.", label="System Prompt") + slider = gr.Slider(10, 100, render=False) + + gr.ChatInterface( + echo, additional_inputs=[system_prompt, slider] + ) + +demo.queue().launch() +``` + If you need to create something even more custom, then its best to construct the chatbot UI using the low-level `gr.Blocks()` API. We have [a dedicated guide for that here](/guides/creating-a-custom-chatbot-with-blocks). ## Using your chatbot via an API -Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message, and will return the response, internally keeping track of the messages sent so far. +Once you've built your Gradio chatbot and are hosting it on [Hugging Face Spaces](https://hf.space) or somewhere else, then you can query it with a simple API at the `/chat` endpoint. The endpoint just expects the user's message (and potentially additional inputs if you have set any using the `additional_inputs` parameter), and will return the response, internally keeping track of the messages sent so far. [](https://github.com/gradio-app/gradio/assets/1778297/7b10d6db-6476-4e2e-bebd-ecda802c3b8f) @@ -251,4 +284,4 @@ def predict(message, history): gr.ChatInterface(predict).queue().launch() ``` -With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building more custom Chabot UI, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API. \ No newline at end of file +With those examples, you should be all set to create your own Gradio Chatbot demos soon! For building even more custom Chatbot applications, check out [a dedicated guide](/guides/creating-a-custom-chatbot-with-blocks) using the low-level `gr.Blocks()` API. \ No newline at end of file From 3d89981d54a1438ad6ac8a4cfadd280bf35c30bb Mon Sep 17 00:00:00 2001 From: gradio-pr-bot Date: Sat, 22 Jul 2023 00:07:33 +0000 Subject: [PATCH 09/13] add changeset --- .changeset/witty-pets-rhyme.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/witty-pets-rhyme.md diff --git a/.changeset/witty-pets-rhyme.md b/.changeset/witty-pets-rhyme.md new file mode 100644 index 0000000000000..c11f3c705842b --- /dev/null +++ b/.changeset/witty-pets-rhyme.md @@ -0,0 +1,5 @@ +--- +"gradio": minor +--- + +feat:WIP: adding `additional_inputs` to `gr.ChatInterface` From eb317a942f148680c7ce00ecd030e7330a5e9d7c Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 21 Jul 2023 21:19:27 -0400 Subject: [PATCH 10/13] Fix typos --- gradio/chat_interface.py | 3 +-- test/test_chat_interface.py | 28 ++++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index ca23d7d9f95c6..aaee1c820e68e 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -76,7 +76,7 @@ def __init__( fn: the function to wrap the chat interface around. Should accept two parameters: a string input message and list of two-element lists of the form [[user_message, bot_message], ...] representing the chat history, and return a string response. See the Chatbot documentation for more information on the chat history format. chatbot: an instance of the gr.Chatbot component to use for the chat interface, if you would like to customize the chatbot properties. If not provided, a default gr.Chatbot component will be created. textbox: an instance of the gr.Textbox component to use for the chat interface, if you would like to customize the textbox properties. If not provided, a default gr.Textbox component will be created. - additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed underthe chatbot, in an accordion. + additional_inputs: an instance or list of instances of gradio components (or their string shortcuts) to use as additional inputs to the chatbot. If components are not already rendered in a surrounding Blocks, then the components will be displayed under the chatbot, in an accordion. additional_inputs_accordion_name: the label of the accordion to use for additional inputs, only used if additional_inputs is provided. examples: sample inputs for the function; if provided, appear below the chatbot and can be clicked to populate the chatbot input. cache_examples: If True, caches examples in the server for fast runtime in examples. The default option in HuggingFace Spaces is True. The default option elsewhere is False. @@ -215,7 +215,6 @@ def __init__( inputs=[self.textbox] + self.additional_inputs, outputs=self.chatbot, fn=examples_fn, - # cache_examples=cache_examples ) any_unrendered_inputs = any( diff --git a/test/test_chat_interface.py b/test/test_chat_interface.py index 710c6243d56bf..c15ddc35856a3 100644 --- a/test/test_chat_interface.py +++ b/test/test_chat_interface.py @@ -1,8 +1,10 @@ +import tempfile from concurrent.futures import wait import pytest import gradio as gr +from gradio import helpers def invalid_fn(message): @@ -76,7 +78,8 @@ async def test_example_caching(self): assert prediction_hi[0][0] == ["hi", "hi hi"] @pytest.mark.asyncio - async def test_example_caching_with_streaming(self): + async def test_example_caching_with_streaming(self, monkeypatch): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) chatbot = gr.ChatInterface( stream, examples=["hello", "hi"], cache_examples=True ) @@ -86,7 +89,8 @@ async def test_example_caching_with_streaming(self): assert prediction_hi[0][0] == ["hi", "hi"] @pytest.mark.asyncio - async def test_example_caching_with_additional_inputs(self): + async def test_example_caching_with_additional_inputs(self, monkeypatch): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) chatbot = gr.ChatInterface( echo_system_prompt_plus_message, additional_inputs=["textbox", "slider"], @@ -98,6 +102,26 @@ async def test_example_caching_with_additional_inputs(self): assert prediction_hello[0][0] == ["hello", "robot hello"] assert prediction_hi[0][0] == ["hi", "ro"] + @pytest.mark.asyncio + async def test_example_caching_with_additional_inputs_already_rendered( + self, monkeypatch + ): + monkeypatch.setattr(helpers, "CACHED_FOLDER", tempfile.mkdtemp()) + with gr.Blocks(): + with gr.Accordion("Inputs"): + text = gr.Textbox() + slider = gr.Slider() + chatbot = gr.ChatInterface( + echo_system_prompt_plus_message, + additional_inputs=[text, slider], + examples=[["hello", "robot", 100], ["hi", "robot", 2]], + cache_examples=True, + ) + prediction_hello = await chatbot.examples_handler.load_from_cache(0) + prediction_hi = await chatbot.examples_handler.load_from_cache(1) + assert prediction_hello[0][0] == ["hello", "robot hello"] + assert prediction_hi[0][0] == ["hi", "ro"] + class TestAPI: def test_get_api_info(self): From 10042856151b0ff79412613e6cb176dfc8642117 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Fri, 21 Jul 2023 21:33:53 -0400 Subject: [PATCH 11/13] Remove label --- gradio/chat_interface.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index aaee1c820e68e..9eb870f277137 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -142,7 +142,6 @@ def __init__( self.textbox = Textbox( container=False, show_label=False, - label="Message", placeholder="Type a message...", scale=7, autofocus=autofocus, From 2c2a3cc5fccfaafac36af5531b6b2bdd96d7c186 Mon Sep 17 00:00:00 2001 From: freddyaboulton Date: Mon, 24 Jul 2023 11:44:07 -0400 Subject: [PATCH 12/13] Revert "Remove label" This reverts commit 10042856151b0ff79412613e6cb176dfc8642117. --- gradio/chat_interface.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gradio/chat_interface.py b/gradio/chat_interface.py index 9eb870f277137..aaee1c820e68e 100644 --- a/gradio/chat_interface.py +++ b/gradio/chat_interface.py @@ -142,6 +142,7 @@ def __init__( self.textbox = Textbox( container=False, show_label=False, + label="Message", placeholder="Type a message...", scale=7, autofocus=autofocus, From d31cacfaddfbdcb50e417d41670d53435cf42fca Mon Sep 17 00:00:00 2001 From: gradio-pr-bot Date: Mon, 24 Jul 2023 15:46:00 +0000 Subject: [PATCH 13/13] add changeset --- .changeset/witty-pets-rhyme.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.changeset/witty-pets-rhyme.md b/.changeset/witty-pets-rhyme.md index c11f3c705842b..eaf2c16692e2f 100644 --- a/.changeset/witty-pets-rhyme.md +++ b/.changeset/witty-pets-rhyme.md @@ -2,4 +2,4 @@ "gradio": minor --- -feat:WIP: adding `additional_inputs` to `gr.ChatInterface` +feat:Adds `additional_inputs` to `gr.ChatInterface`