Skip to content

Commit

Permalink
Make it easy to prompt user for input in ChatFeed (#7148)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Aug 22, 2024
1 parent e9b62ec commit 3b0d8c2
Show file tree
Hide file tree
Showing 4 changed files with 339 additions and 1 deletion.
117 changes: 117 additions & 0 deletions examples/reference/chat/ChatFeed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,123 @@
"See [`ChatStep`](ChatStep.ipynb) for more details on how to use those components."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Prompt User\n",
"\n",
"It is possible to temporarily pause the execution of code and prompt the user to answer a question, or fill out a form, using `prompt_user`, which accepts any Panel `component` and a follow-up `callback` (with `component` and `instance` as args) to execute upon submission."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def repeat_answer(component, instance):\n",
" contents = component.value\n",
" instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n",
"\n",
"\n",
"def show_interest(contents, user, instance):\n",
" if \"ice\" in contents or \"cream\" in contents:\n",
" answer_input = pn.widgets.TextInput(\n",
" placeholder=\"Enter your favorite ice cream flavor\"\n",
" )\n",
" instance.prompt_user(answer_input, callback=repeat_answer)\n",
" else:\n",
" return \"I'm not interested in that topic.\"\n",
"\n",
"\n",
"chat_feed = pn.chat.ChatFeed(\n",
" callback=show_interest,\n",
" callback_user=\"Ice Cream Bot\",\n",
")\n",
"chat_feed"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat_feed.send(\"food\");"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also set a `predicate` to evaluate the component's state, e.g. widget has value. If provided, the submit button will be enabled when the predicate returns `True`. The `predicate` should accept the component as an argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_chocolate(component):\n",
" return \"chocolate\" in component.value.lower()\n",
"\n",
"\n",
"def repeat_answer(component, instance):\n",
" contents = component.value\n",
" instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n",
"\n",
"\n",
"def show_interest(contents, user, instance):\n",
" if \"ice\" in contents or \"cream\" in contents:\n",
" answer_input = pn.widgets.TextInput(\n",
" placeholder=\"Enter your favorite ice cream flavor\"\n",
" )\n",
" instance.prompt_user(answer_input, callback=repeat_answer, predicate=is_chocolate)\n",
" else:\n",
" return \"I'm not interested in that topic.\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also set a `timeout` in seconds and `timeout_message` to prevent submission after a certain time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_chocolate(component):\n",
" return \"chocolate\" in component.value.lower()\n",
"\n",
"\n",
"def repeat_answer(component, instance):\n",
" contents = component.value\n",
" instance.send(f\"Wow, {contents}, that's my favorite flavor too!\", respond=False, user=\"Ice Cream Bot\")\n",
"\n",
"\n",
"def show_interest(contents, user, instance):\n",
" if \"ice\" in contents or \"cream\" in contents:\n",
" answer_input = pn.widgets.TextInput(\n",
" placeholder=\"Enter your favorite ice cream flavor\"\n",
" )\n",
" instance.prompt_user(answer_input, callback=repeat_answer, predicate=is_chocolate, timeout=10, timeout_message=\"You're too slow!\")\n",
" else:\n",
" return \"I'm not interested in that topic.\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lastly, use `button_params` and `timeout_button_params` to customize the appearance of the buttons and timeout button, respectively."
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
87 changes: 86 additions & 1 deletion panel/chat/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@

from .._param import Margin
from ..io.resources import CDN_DIST
from ..layout import Column, Feed, ListPanel
from ..layout import (
Column, Feed, ListPanel, WidgetBox,
)
from ..layout.card import Card
from ..layout.spacer import VSpacer
from ..pane.image import SVG, ImageBase
from ..pane.markup import HTML, Markdown
from ..util import to_async_gen
from ..viewable import Children
from ..widgets import Widget
from ..widgets.button import Button
from .icon import ChatReactionIcons
from .message import ChatMessage
from .step import ChatStep
Expand Down Expand Up @@ -198,6 +202,8 @@ class ChatFeed(ListPanel):
_callback_state = param.ObjectSelector(objects=list(CallbackState), doc="""
The current state of the callback.""")

_prompt_trigger = param.Event(doc="Triggers the prompt input.")

_callback_trigger = param.Event(doc="Triggers the callback to respond.")

_post_hook_trigger = param.Event(doc="Triggers the append callback.")
Expand Down Expand Up @@ -807,6 +813,85 @@ def add_step(
self._chat_log.scroll_to_latest()
return step

def prompt_user(
self,
component: Widget | ListPanel,
callback: Callable | None = None,
predicate: Callable | None = None,
timeout: int = 120,
timeout_message: str = "Timed out",
button_params: dict | None = None,
timeout_button_params: dict | None = None,
**send_kwargs
) -> None:
"""
Prompts the user to interact with a form component.
Arguments
---------
component : Widget | ListPanel
The component to prompt the user with.
callback : Callable
The callback to execute once the user submits the form.
The callback should accept two arguments: the component
and the ChatFeed instance.
predicate : Callable | None
A predicate to evaluate the component's state, e.g. widget has value.
If provided, the button will be enabled when the predicate returns True.
The predicate should accept the component as an argument.
timeout : int
The duration in seconds to wait before timing out.
timeout_message : str
The message to display when the timeout is reached.
button_params : dict | None
Additional parameters to pass to the submit button.
timeout_button_params : dict | None
Additional parameters to pass to the timeout button.
"""
async def _prepare_prompt(*_) -> None:
input_button_params = button_params or {}
if "name" not in input_button_params:
input_button_params["name"] = "Submit"
if "margin" not in input_button_params:
input_button_params["margin"] = (5, 10)
if "button_type" not in input_button_params:
input_button_params["button_type"] = "primary"
if "icon" not in input_button_params:
input_button_params["icon"] = "check"
submit_button = Button(**input_button_params)

form = WidgetBox(component, submit_button, margin=(5, 10), css_classes=["message"])
if "user" not in send_kwargs:
send_kwargs["user"] = "Input"
self.send(form, respond=False, **send_kwargs)

for _ in range(timeout * 10): # sleeping for 0.1 seconds
is_fulfilled = predicate(component) if predicate else True
submit_button.disabled = not is_fulfilled
if submit_button.clicks > 0:
with param.parameterized.batch_call_watchers(self):
submit_button.visible = False
form.disabled = True
if callback is not None:
result = callback(component, self)
if isawaitable(result):
await result
break
await asyncio.sleep(0.1)
else:
input_timeout_button_params = timeout_button_params or {}
if "name" not in input_timeout_button_params:
input_timeout_button_params["name"] = timeout_message
if "button_type" not in input_timeout_button_params:
input_timeout_button_params["button_type"] = "light"
if "icon" not in input_timeout_button_params:
input_timeout_button_params["icon"] = "x"
with param.parameterized.batch_call_watchers(self):
submit_button.param.update(**input_timeout_button_params)
form.disabled = True

param.parameterized.async_executor(_prepare_prompt)

def respond(self):
"""
Executes the callback with the latest message in the chat log.
Expand Down
2 changes: 2 additions & 0 deletions panel/chat/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
SYSTEM_LOGO = "⚙️"
ERROR_LOGO = "❌"
HELP_LOGO = "❓"
INPUT_LOGO = "❗"
GPT_3_LOGO = "{dist_path}assets/logo/gpt-3.svg"
GPT_4_LOGO = "{dist_path}assets/logo/gpt-4.svg"
WOLFRAM_LOGO = "{dist_path}assets/logo/wolfram.svg"
Expand Down Expand Up @@ -79,6 +80,7 @@
"exception": ERROR_LOGO,
"error": ERROR_LOGO,
"help": HELP_LOGO,
"input": INPUT_LOGO,
# Human
"adult": "🧑",
"baby": "👶",
Expand Down
134 changes: 134 additions & 0 deletions panel/tests/chat/test_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,140 @@ def test_update_chat_log_params(self, chat_feed):
assert chat_feed._chat_log.scroll_button_threshold == 10
assert chat_feed._chat_log.auto_scroll_limit == 10


@pytest.mark.xdist_group("chat")
class TestChatFeedPromptUser:

async def test_prompt_user_basic(self, chat_feed):
text_input = TextInput()

def callback(component, feed):
feed.send(component.value)

async def prompt_and_submit():
chat_feed.prompt_user(text_input, callback)
await async_wait_until(lambda: len(chat_feed.objects) == 1)
text_input.value = "test input"
submit_button = chat_feed.objects[-1].object[1]
submit_button.clicks += 1
await async_wait_until(lambda: len(chat_feed.objects) == 2)

await asyncio.wait_for(prompt_and_submit(), timeout=5.0)
assert chat_feed.objects[-1].object == "test input"

async def test_prompt_user_with_predicate(self, chat_feed):
text_input = TextInput()

def predicate(component):
return len(component.value) > 5

def callback(component, feed):
feed.send(component.value)

async def prompt_and_submit():
chat_feed.prompt_user(text_input, callback, predicate=predicate)
await async_wait_until(lambda: len(chat_feed.objects) == 1)

text_input.value = "short"
submit_button = chat_feed.objects[-1].object[1]
assert submit_button.disabled

text_input.value = "long enough"
await async_wait_until(lambda: not submit_button.disabled)

submit_button.clicks += 1
await async_wait_until(lambda: len(chat_feed.objects) == 2)

await asyncio.wait_for(prompt_and_submit(), timeout=5.0)
assert chat_feed.objects[-1].object == "long enough"

async def test_prompt_user_timeout(self, chat_feed):
text_input = TextInput()

def callback(component, feed):
pytest.fail("Callback should not be called on timeout")

async def prompt_and_wait():
chat_feed.prompt_user(text_input, callback, timeout=1)
await async_wait_until(lambda: len(chat_feed.objects) == 1)
await async_wait_until(lambda: chat_feed.objects[-1].object[1].disabled)

await asyncio.wait_for(prompt_and_wait(), timeout=5.0)

submit_button = chat_feed.objects[-1].object[1]
assert submit_button.name == "Timed out"
assert submit_button.button_type == "light"
assert submit_button.icon == "x"

async def test_prompt_user_custom_button_params(self, chat_feed):
text_input = TextInput()

def callback(component, feed):
feed.send(component.value)

custom_button_params = {
"name": "Custom Submit",
"button_type": "success",
"icon": "arrow-right"
}

async def prompt_and_check():
chat_feed.prompt_user(text_input, callback, button_params=custom_button_params)
await async_wait_until(lambda: len(chat_feed.objects) == 1)

await asyncio.wait_for(prompt_and_check(), timeout=5.0)

submit_button = chat_feed.objects[-1].object[1]
assert submit_button.name == "Custom Submit"
assert submit_button.button_type == "success"
assert submit_button.icon == "arrow-right"

async def test_prompt_user_custom_timeout_button_params(self, chat_feed):
text_input = TextInput()

def callback(component, feed):
pytest.fail("Callback should not be called on timeout")

custom_timeout_params = {
"name": "Custom Timeout",
"button_type": "danger",
"icon": "alert-triangle"
}

async def prompt_and_wait():
chat_feed.prompt_user(text_input, callback, timeout=1, timeout_button_params=custom_timeout_params)
await async_wait_until(lambda: len(chat_feed.objects) == 1)
await async_wait_until(lambda: chat_feed.objects[-1].object[1].disabled)

await asyncio.wait_for(prompt_and_wait(), timeout=5.0)

submit_button = chat_feed.objects[-1].object[1]
assert submit_button.name == "Custom Timeout"
assert submit_button.button_type == "danger"
assert submit_button.icon == "alert-triangle"

async def test_prompt_user_async(self, chat_feed):
text_input = TextInput()

async def async_callback(component, feed):
await asyncio.sleep(0.1)
feed.send("Callback executed")

async def prompt_and_submit():
chat_feed.prompt_user(text_input, async_callback)
await async_wait_until(lambda: len(chat_feed.objects) == 1)

submit_button = chat_feed.objects[-1].object[1]
submit_button.clicks += 1

await async_wait_until(lambda: len(chat_feed.objects) == 2)

await asyncio.wait_for(prompt_and_submit(), timeout=5.0)

assert chat_feed.objects[-1].object == "Callback executed"
assert chat_feed.objects[-2].object.disabled == True


@pytest.mark.xdist_group("chat")
class TestChatFeedCallback:

Expand Down

0 comments on commit 3b0d8c2

Please sign in to comment.