diff --git a/examples/reference/chat/ChatFeed.ipynb b/examples/reference/chat/ChatFeed.ipynb index 5b9b22c594..bf117aeb7c 100644 --- a/examples/reference/chat/ChatFeed.ipynb +++ b/examples/reference/chat/ChatFeed.ipynb @@ -731,6 +731,123 @@ "See [`ChatStep`](ChatStep.ipynb) for more details on how to use those components." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt User\n", + "\n", + "It is possible to temporarily pause the execution of code and prompt the user to answer a question, or fill out a form, using `prompt_user`, which accepts any Panel `component` and a follow-up `callback` (with `component` and `instance` as args) to execute upon submission." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def repeat_answer(component, instance):\n", + " contents = component.value\n", + " instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n", + "\n", + "\n", + "def show_interest(contents, user, instance):\n", + " if \"ice\" in contents or \"cream\" in contents:\n", + " answer_input = pn.widgets.TextInput(\n", + " placeholder=\"Enter your favorite ice cream flavor\"\n", + " )\n", + " instance.prompt_user(answer_input, callback=repeat_answer)\n", + " else:\n", + " return \"I'm not interested in that topic.\"\n", + "\n", + "\n", + "chat_feed = pn.chat.ChatFeed(\n", + " callback=show_interest,\n", + " callback_user=\"Ice Cream Bot\",\n", + ")\n", + "chat_feed" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "chat_feed.send(\"food\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also set a `predicate` to evaluate the component's state, e.g. widget has value. If provided, the submit button will be enabled when the predicate returns `True`. The `predicate` should accept the component as an argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def is_chocolate(component):\n", + " return \"chocolate\" in component.value.lower()\n", + "\n", + "\n", + "def repeat_answer(component, instance):\n", + " contents = component.value\n", + " instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n", + "\n", + "\n", + "def show_interest(contents, user, instance):\n", + " if \"ice\" in contents or \"cream\" in contents:\n", + " answer_input = pn.widgets.TextInput(\n", + " placeholder=\"Enter your favorite ice cream flavor\"\n", + " )\n", + " instance.prompt_user(answer_input, callback=repeat_answer, predicate=is_chocolate)\n", + " else:\n", + " return \"I'm not interested in that topic.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also set a `timeout` in seconds and `timeout_message` to prevent submission after a certain time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def is_chocolate(component):\n", + " return \"chocolate\" in component.value.lower()\n", + "\n", + "\n", + "def repeat_answer(component, instance):\n", + " contents = component.value\n", + " instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n", + "\n", + "\n", + "def show_interest(contents, user, instance):\n", + " if \"ice\" in contents or \"cream\" in contents:\n", + " answer_input = pn.widgets.TextInput(\n", + " placeholder=\"Enter your favorite ice cream flavor\"\n", + " )\n", + " instance.prompt_user(answer_input, callback=repeat_answer, predicate=is_chocolate, timeout=10, timeout_message=\"You're too slow!\")\n", + " else:\n", + " return \"I'm not interested in that topic.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, use `button_params` and `timeout_button_params` to customize the appearance of the buttons and timeout button, respectively." + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/panel/chat/feed.py b/panel/chat/feed.py index 955f04e0b4..a0924f46e7 100644 --- a/panel/chat/feed.py +++ b/panel/chat/feed.py @@ -22,13 +22,17 @@ from .._param import Margin from ..io.resources import CDN_DIST -from ..layout import Column, Feed, ListPanel +from ..layout import ( + Column, Feed, ListPanel, WidgetBox, +) from ..layout.card import Card from ..layout.spacer import VSpacer from ..pane.image import SVG, ImageBase from ..pane.markup import HTML, Markdown from ..util import to_async_gen from ..viewable import Children +from ..widgets import Widget +from ..widgets.button import Button from .icon import ChatReactionIcons from .message import ChatMessage from .step import ChatStep @@ -198,6 +202,8 @@ class ChatFeed(ListPanel): _callback_state = param.ObjectSelector(objects=list(CallbackState), doc=""" The current state of the callback.""") + _prompt_trigger = param.Event(doc="Triggers the prompt input.") + _callback_trigger = param.Event(doc="Triggers the callback to respond.") _post_hook_trigger = param.Event(doc="Triggers the append callback.") @@ -807,6 +813,85 @@ def add_step( self._chat_log.scroll_to_latest() return step + def prompt_user( + self, + component: Widget | ListPanel, + callback: Callable | None = None, + predicate: Callable | None = None, + timeout: int = 120, + timeout_message: str = "Timed out", + button_params: dict | None = None, + timeout_button_params: dict | None = None, + **send_kwargs + ) -> None: + """ + Prompts the user to interact with a form component. + + Arguments + --------- + component : Widget | ListPanel + The component to prompt the user with. + callback : Callable + The callback to execute once the user submits the form. + The callback should accept two arguments: the component + and the ChatFeed instance. + predicate : Callable | None + A predicate to evaluate the component's state, e.g. widget has value. + If provided, the button will be enabled when the predicate returns True. + The predicate should accept the component as an argument. + timeout : int + The duration in seconds to wait before timing out. + timeout_message : str + The message to display when the timeout is reached. + button_params : dict | None + Additional parameters to pass to the submit button. + timeout_button_params : dict | None + Additional parameters to pass to the timeout button. + """ + async def _prepare_prompt(*_) -> None: + input_button_params = button_params or {} + if "name" not in input_button_params: + input_button_params["name"] = "Submit" + if "margin" not in input_button_params: + input_button_params["margin"] = (5, 10) + if "button_type" not in input_button_params: + input_button_params["button_type"] = "primary" + if "icon" not in input_button_params: + input_button_params["icon"] = "check" + submit_button = Button(**input_button_params) + + form = WidgetBox(component, submit_button, margin=(5, 10), css_classes=["message"]) + if "user" not in send_kwargs: + send_kwargs["user"] = "Input" + self.send(form, respond=False, **send_kwargs) + + for _ in range(timeout * 10): # sleeping for 0.1 seconds + is_fulfilled = predicate(component) if predicate else True + submit_button.disabled = not is_fulfilled + if submit_button.clicks > 0: + with param.parameterized.batch_call_watchers(self): + submit_button.visible = False + form.disabled = True + if callback is not None: + result = callback(component, self) + if isawaitable(result): + await result + break + await asyncio.sleep(0.1) + else: + input_timeout_button_params = timeout_button_params or {} + if "name" not in input_timeout_button_params: + input_timeout_button_params["name"] = timeout_message + if "button_type" not in input_timeout_button_params: + input_timeout_button_params["button_type"] = "light" + if "icon" not in input_timeout_button_params: + input_timeout_button_params["icon"] = "x" + with param.parameterized.batch_call_watchers(self): + submit_button.param.update(**input_timeout_button_params) + form.disabled = True + + param.parameterized.async_executor(_prepare_prompt) + def respond(self): """ Executes the callback with the latest message in the chat log. diff --git a/panel/chat/message.py b/panel/chat/message.py index b8d22fe2c4..67d8cc5c71 100644 --- a/panel/chat/message.py +++ b/panel/chat/message.py @@ -50,6 +50,7 @@ SYSTEM_LOGO = "⚙️" ERROR_LOGO = "❌" HELP_LOGO = "❓" +INPUT_LOGO = "❗" GPT_3_LOGO = "{dist_path}assets/logo/gpt-3.svg" GPT_4_LOGO = "{dist_path}assets/logo/gpt-4.svg" WOLFRAM_LOGO = "{dist_path}assets/logo/wolfram.svg" @@ -79,6 +80,7 @@ "exception": ERROR_LOGO, "error": ERROR_LOGO, "help": HELP_LOGO, + "input": INPUT_LOGO, # Human "adult": "🧑", "baby": "👶", diff --git a/panel/tests/chat/test_feed.py b/panel/tests/chat/test_feed.py index 096afb4cd0..64aa7304d4 100644 --- a/panel/tests/chat/test_feed.py +++ b/panel/tests/chat/test_feed.py @@ -608,6 +608,140 @@ def test_update_chat_log_params(self, chat_feed): assert chat_feed._chat_log.scroll_button_threshold == 10 assert chat_feed._chat_log.auto_scroll_limit == 10 + +@pytest.mark.xdist_group("chat") +class TestChatFeedPromptUser: + + async def test_prompt_user_basic(self, chat_feed): + text_input = TextInput() + + def callback(component, feed): + feed.send(component.value) + + async def prompt_and_submit(): + chat_feed.prompt_user(text_input, callback) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + text_input.value = "test input" + submit_button = chat_feed.objects[-1].object[1] + submit_button.clicks += 1 + await async_wait_until(lambda: len(chat_feed.objects) == 2) + + await asyncio.wait_for(prompt_and_submit(), timeout=5.0) + assert chat_feed.objects[-1].object == "test input" + + async def test_prompt_user_with_predicate(self, chat_feed): + text_input = TextInput() + + def predicate(component): + return len(component.value) > 5 + + def callback(component, feed): + feed.send(component.value) + + async def prompt_and_submit(): + chat_feed.prompt_user(text_input, callback, predicate=predicate) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + + text_input.value = "short" + submit_button = chat_feed.objects[-1].object[1] + assert submit_button.disabled + + text_input.value = "long enough" + await async_wait_until(lambda: not submit_button.disabled) + + submit_button.clicks += 1 + await async_wait_until(lambda: len(chat_feed.objects) == 2) + + await asyncio.wait_for(prompt_and_submit(), timeout=5.0) + assert chat_feed.objects[-1].object == "long enough" + + async def test_prompt_user_timeout(self, chat_feed): + text_input = TextInput() + + def callback(component, feed): + pytest.fail("Callback should not be called on timeout") + + async def prompt_and_wait(): + chat_feed.prompt_user(text_input, callback, timeout=1) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + await async_wait_until(lambda: chat_feed.objects[-1].object[1].disabled) + + await asyncio.wait_for(prompt_and_wait(), timeout=5.0) + + submit_button = chat_feed.objects[-1].object[1] + assert submit_button.name == "Timed out" + assert submit_button.button_type == "light" + assert submit_button.icon == "x" + + async def test_prompt_user_custom_button_params(self, chat_feed): + text_input = TextInput() + + def callback(component, feed): + feed.send(component.value) + + custom_button_params = { + "name": "Custom Submit", + "button_type": "success", + "icon": "arrow-right" + } + + async def prompt_and_check(): + chat_feed.prompt_user(text_input, callback, button_params=custom_button_params) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + + await asyncio.wait_for(prompt_and_check(), timeout=5.0) + + submit_button = chat_feed.objects[-1].object[1] + assert submit_button.name == "Custom Submit" + assert submit_button.button_type == "success" + assert submit_button.icon == "arrow-right" + + async def test_prompt_user_custom_timeout_button_params(self, chat_feed): + text_input = TextInput() + + def callback(component, feed): + pytest.fail("Callback should not be called on timeout") + + custom_timeout_params = { + "name": "Custom Timeout", + "button_type": "danger", + "icon": "alert-triangle" + } + + async def prompt_and_wait(): + chat_feed.prompt_user(text_input, callback, timeout=1, timeout_button_params=custom_timeout_params) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + await async_wait_until(lambda: chat_feed.objects[-1].object[1].disabled) + + await asyncio.wait_for(prompt_and_wait(), timeout=5.0) + + submit_button = chat_feed.objects[-1].object[1] + assert submit_button.name == "Custom Timeout" + assert submit_button.button_type == "danger" + assert submit_button.icon == "alert-triangle" + + async def test_prompt_user_async(self, chat_feed): + text_input = TextInput() + + async def async_callback(component, feed): + await asyncio.sleep(0.1) + feed.send("Callback executed") + + async def prompt_and_submit(): + chat_feed.prompt_user(text_input, async_callback) + await async_wait_until(lambda: len(chat_feed.objects) == 1) + + submit_button = chat_feed.objects[-1].object[1] + submit_button.clicks += 1 + + await async_wait_until(lambda: len(chat_feed.objects) == 2) + + await asyncio.wait_for(prompt_and_submit(), timeout=5.0) + + assert chat_feed.objects[-1].object == "Callback executed" + assert chat_feed.objects[-2].object.disabled == True + + @pytest.mark.xdist_group("chat") class TestChatFeedCallback: