diff --git a/panel/chat/feed.py b/panel/chat/feed.py index df46232a16..d7301cb8a8 100644 --- a/panel/chat/feed.py +++ b/panel/chat/feed.py @@ -10,7 +10,10 @@ from enum import Enum from functools import partial -from inspect import isasyncgen, isawaitable, isgenerator +from inspect import ( + isasyncgen, isasyncgenfunction, isawaitable, iscoroutinefunction, + isgenerator, +) from io import BytesIO from typing import ( TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, @@ -479,12 +482,24 @@ async def _schedule_placeholder( return start = asyncio.get_event_loop().time() - while not self._callback_state == CallbackState.IDLE and num_entries == len(self._chat_log): + while not task.done() and num_entries == len(self._chat_log): duration = asyncio.get_event_loop().time() - start if duration > self.placeholder_threshold: self.append(self._placeholder) return - await asyncio.sleep(0.28) + await asyncio.sleep(0.1) + + async def _handle_callback(self, message, loop): + callback_args = self._gather_callback_args(message) + if iscoroutinefunction(self.callback): + response = await self.callback(*callback_args) + elif isasyncgenfunction(self.callback): + response = self.callback(*callback_args) + else: + response = await loop.run_in_executor( + None, partial(self.callback, *callback_args) + ) + await self._serialize_response(response) async def _prepare_response(self, _) -> None: """ @@ -505,19 +520,12 @@ async def _prepare_response(self, _) -> None: return num_entries = len(self._chat_log) - callback_args = self._gather_callback_args(message) loop = asyncio.get_event_loop() - if asyncio.iscoroutinefunction(self.callback): - future = loop.create_task(self.callback(*callback_args)) - else: - future = loop.run_in_executor(None, partial(self.callback, *callback_args)) + future = loop.create_task(self._handle_callback(message, loop)) self._callback_future = future - await self._schedule_placeholder(future, num_entries) - - if not future.cancelled(): - await future - response = future.result() - await self._serialize_response(response) + await asyncio.gather( + self._schedule_placeholder(future, num_entries), future, + ) except StopCallback: # callback was stopped by user self._callback_state = CallbackState.STOPPED @@ -536,10 +544,16 @@ async def _prepare_response(self, _) -> None: else: raise e finally: - with param.parameterized.batch_call_watchers(self): - self._replace_placeholder(None) - self._callback_state = CallbackState.IDLE - self.disabled = self._was_disabled + await self._cleanup_response() + + async def _cleanup_response(self): + """ + Events to always execute after the callback is done. + """ + with param.parameterized.batch_call_watchers(self): + self._replace_placeholder(None) + self._callback_state = CallbackState.IDLE + self.disabled = self._was_disabled # Public API diff --git a/panel/chat/interface.py b/panel/chat/interface.py index a612a53be9..097758ca72 100644 --- a/panel/chat/interface.py +++ b/panel/chat/interface.py @@ -590,3 +590,10 @@ async def _update_input_disabled(self): with param.parameterized.batch_call_watchers(self): self._buttons["send"].visible = False self._buttons["stop"].visible = True + + async def _cleanup_response(self): + """ + Events to always execute after the callback is done. + """ + await super()._cleanup_response() + await self._update_input_disabled() diff --git a/panel/tests/chat/test_feed.py b/panel/tests/chat/test_feed.py index a7769a301f..c709217fb7 100644 --- a/panel/tests/chat/test_feed.py +++ b/panel/tests/chat/test_feed.py @@ -20,6 +20,8 @@ "max_width": 201, } +ChatFeed.callback_exception = "raise" + @pytest.fixture def chat_feed(): @@ -525,7 +527,6 @@ async def echo(contents, user, instance): assert len(chat_feed.objects) == 2 assert chat_feed.objects[1].object == "Message" - @pytest.mark.asyncio def test_generator(self, chat_feed): async def echo(contents, user, instance): message = "" @@ -580,7 +581,6 @@ def echo(contents, user, instance): chat_feed.callback = echo chat_feed.send("Message", respond=True) assert chat_feed._placeholder not in chat_feed._chat_log - # append sent message and placeholder def test_placeholder_threshold_under(self, chat_feed): async def echo(contents, user, instance): @@ -617,13 +617,13 @@ async def echo(contents, user, instance): def test_placeholder_threshold_exceed_generator(self, chat_feed): async def echo(contents, user, instance): - assert instance._placeholder not in instance._chat_log + await async_wait_until(lambda: instance._placeholder not in instance._chat_log) await asyncio.sleep(0.5) - assert instance._placeholder in instance._chat_log + await async_wait_until(lambda: instance._placeholder in instance._chat_log) yield "hello testing" - assert instance._placeholder not in instance._chat_log + await async_wait_until(lambda: instance._placeholder not in instance._chat_log) - chat_feed.placeholder_threshold = 0.2 + chat_feed.placeholder_threshold = 1 chat_feed.callback = echo chat_feed.send("Message", respond=True) assert chat_feed._placeholder not in chat_feed._chat_log @@ -712,7 +712,10 @@ async def callback(msg, user, instance): yield "B" chat_feed.callback = callback - chat_feed.send("Message", respond=True) + try: + chat_feed.send("Message", respond=True) + except asyncio.CancelledError: # tests pick up this error + pass # use sleep here instead of wait for because # the callback is timed and I want to confirm stop works time.sleep(1) @@ -726,7 +729,10 @@ async def callback(msg, user, instance): instance.stream("B", message=message) chat_feed.callback = callback - chat_feed.send("Message", respond=True) + try: + chat_feed.send("Message", respond=True) + except asyncio.CancelledError: + pass # use sleep here instead of wait for because # the callback is timed and I want to confirm stop works time.sleep(1) @@ -740,7 +746,10 @@ def callback(msg, user, instance): instance.stream("B", message=message) # should not reach this point chat_feed.callback = callback - chat_feed.send("Message", respond=True) + try: + chat_feed.send("Message", respond=True) + except asyncio.CancelledError: + pass # use sleep here instead of wait for because # the callback is timed and I want to confirm stop works time.sleep(1) diff --git a/panel/tests/chat/test_interface.py b/panel/tests/chat/test_interface.py index f056f14467..3594f168e4 100644 --- a/panel/tests/chat/test_interface.py +++ b/panel/tests/chat/test_interface.py @@ -1,3 +1,5 @@ +import asyncio + from io import BytesIO import pytest @@ -6,10 +8,12 @@ from panel.chat.interface import ChatInterface from panel.layout import Row, Tabs from panel.pane import Image -from panel.tests.util import wait_until +from panel.tests.util import async_wait_until, wait_until from panel.widgets.button import Button from panel.widgets.input import FileInput, TextAreaInput, TextInput +ChatInterface.callback_exception = "raise" + class TestChatInterface: @pytest.fixture @@ -88,12 +92,10 @@ def test_click_send(self, chat_interface: ChatInterface): def test_show_stop_disabled(self, chat_interface: ChatInterface): async def callback(msg, user, instance): yield "A" - send_button = chat_interface._input_layout[1] - stop_button = chat_interface._input_layout[2] - assert send_button.name == "Send" - assert stop_button.name == "Stop" - assert send_button.visible - assert not send_button.disabled + send_button = instance._buttons["send"] + stop_button = instance._buttons["stop"] + wait_until(lambda: send_button.visible) + wait_until(lambda: send_button.disabled) # should be disabled while callback is running assert not stop_button.visible yield "B" # should not stream this @@ -110,12 +112,10 @@ async def callback(msg, user, instance): def test_show_stop_for_async(self, chat_interface: ChatInterface): async def callback(msg, user, instance): - send_button = instance._input_layout[1] - stop_button = instance._input_layout[2] - assert send_button.name == "Send" - assert stop_button.name == "Stop" - assert not send_button.visible - assert stop_button.visible + send_button = instance._buttons["send"] + stop_button = instance._buttons["stop"] + await async_wait_until(lambda: stop_button.visible) + await async_wait_until(lambda: not send_button.visible) chat_interface.callback = callback chat_interface.send("Message", respond=True) @@ -124,12 +124,10 @@ async def callback(msg, user, instance): def test_show_stop_for_sync(self, chat_interface: ChatInterface): def callback(msg, user, instance): - send_button = instance._input_layout[1] - stop_button = instance._input_layout[2] - assert send_button.name == "Send" - assert stop_button.name == "Stop" - assert not send_button.visible - assert stop_button.visible + send_button = instance._buttons["send"] + stop_button = instance._buttons["stop"] + wait_until(lambda: stop_button.visible) + wait_until(lambda: not send_button.visible) chat_interface.callback = callback chat_interface.send("Message", respond=True) @@ -138,25 +136,21 @@ def callback(msg, user, instance): def test_click_stop(self, chat_interface: ChatInterface): async def callback(msg, user, instance): - send_button = instance._input_layout[1] - stop_button = instance._input_layout[2] - assert send_button.name == "Send" - assert stop_button.name == "Stop" - assert not send_button.visible - assert stop_button.visible - wait_until(lambda: len(instance.objects) == 2) - assert instance._placeholder in instance.objects + send_button = instance._buttons["send"] + stop_button = instance._buttons["stop"] + await async_wait_until(lambda: stop_button.visible) + await async_wait_until(lambda: not send_button.visible) instance._click_stop(None) - assert send_button.visible - assert not send_button.disabled - assert not stop_button.visible - assert instance._placeholder not in instance.objects chat_interface.callback = callback chat_interface.placeholder_threshold = 0.001 - chat_interface.send("Message", respond=True) - send_button = chat_interface._input_layout[1] - assert not send_button.disabled + try: + chat_interface.send("Message", respond=True) + except asyncio.exceptions.CancelledError: + pass + wait_until(lambda: not chat_interface._buttons["send"].disabled) + wait_until(lambda: chat_interface._buttons["send"].visible) + wait_until(lambda: not chat_interface._buttons["stop"].visible) @pytest.mark.parametrize("widget", [TextInput(), TextAreaInput()]) def test_auto_send_types(self, chat_interface: ChatInterface, widget):