Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix stream ChatMessage for ChatInterface and mention serialize #6452

Merged
merged 3 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions examples/reference/chat/ChatFeed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"##### Core\n",
"\n",
"* **`send`**: Sends a value and creates a new message in the chat log. If `respond` is `True`, additionally executes the callback, if provided.\n",
"* **`serialize`**: Exports the chat log as a dict; primarily for use with `transformers`.\n",
"* **`stream`**: Streams a token and updates the provided message, if provided. Otherwise creates a new message in the chat log, so be sure the returned message is passed back into the method, e.g. `message = chat.stream(token, message=message)`. This method is primarily for outputs that are not generators--notably LangChain. For most cases, use the send method instead.\n",
"\n",
"##### Other\n",
Expand Down Expand Up @@ -696,6 +697,63 @@
"chat_feed.serialize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the output is complex, you can pass a `custom_serializer` to only keep the text part."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"complex_output = pn.Tabs((\"Code\", \"`print('Hello World)`\"), (\"Output\", \"Hello World\"))\n",
"chat_feed = pn.chat.ChatFeed(pn.chat.ChatMessage(complex_output))\n",
"chat_feed"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's the output without:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chat_feed.serialize()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's the output with a `custom_serializer`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def custom_serializer(obj):\n",
" if isinstance(obj, pn.Tabs):\n",
" # only keep the first tab's content\n",
" return obj[0].object\n",
" # fall back to the default serialization\n",
" return obj.serialize()\n",
"\n",
"chat_feed.serialize(custom_serializer=custom_serializer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
2 changes: 1 addition & 1 deletion panel/chat/feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def send(

def stream(
self,
value: str,
value: str | dict | ChatMessage,
user: str | None = None,
avatar: str | bytes | BytesIO | None = None,
message: ChatMessage | None = None,
Expand Down
9 changes: 7 additions & 2 deletions panel/chat/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def send(

def stream(
self,
value: str,
value: str | dict | ChatMessage,
user: str | None = None,
avatar: str | bytes | BytesIO | None = None,
message: ChatMessage | None = None,
Expand Down Expand Up @@ -675,4 +675,9 @@ def stream(
-------
The message that was updated.
"""
return super().stream(value, user=user or self.user, avatar=avatar or self.avatar, message=message, replace=replace)
if not isinstance(value, ChatMessage):
# ChatMessage cannot set user or avatar when explicitly streaming
# so only set to the default when not a ChatMessage
user = user or self.user
avatar = avatar or self.avatar
return super().stream(value, user=user, avatar=avatar, message=message, replace=replace)
7 changes: 6 additions & 1 deletion panel/tests/chat/test_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_stream_dict_with_user_avatar_override(self, chat_feed):
assert chat_feed.objects[0].user == user
assert chat_feed.objects[0].avatar == avatar

def test_stream_entry(self, chat_feed):
def test_stream_message(self, chat_feed):
message = ChatMessage("Streaming message", user="Person", avatar="P")
chat_feed.stream(message)
wait_until(lambda: len(chat_feed.objects) == 1)
Expand All @@ -236,6 +236,11 @@ def test_stream_entry(self, chat_feed):
assert chat_feed.objects[0].user == "Person"
assert chat_feed.objects[0].avatar == "P"

def test_stream_message_error_passed_user_avatar(self, chat_feed):
message = ChatMessage("Streaming message", user="Person", avatar="P")
with pytest.raises(ValueError, match="Cannot set user or avatar"):
chat_feed.stream(message, user="Bob", avatar="👨")

def test_stream_replace(self, chat_feed):
message = chat_feed.stream("Hello")
wait_until(lambda: len(chat_feed.objects) == 1)
Expand Down
20 changes: 20 additions & 0 deletions panel/tests/chat/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from panel.chat.input import ChatAreaInput
from panel.chat.interface import ChatInterface
from panel.chat.message import ChatMessage
from panel.layout import Row, Tabs
from panel.pane import Image
from panel.tests.util import async_wait_until, wait_until
Expand Down Expand Up @@ -379,6 +380,25 @@ def test_manual_user(self):
chat_interface.send("Test")
assert chat_interface.objects[0].user == "New User"

def test_stream_chat_message(self, chat_interface):
chat_interface.stream(ChatMessage("testeroo", user="useroo", avatar="avataroo"))
chat_message = chat_interface.objects[0]
assert chat_message.user == "useroo"
assert chat_message.avatar == "avataroo"
assert chat_message.object == "testeroo"

def test_stream_chat_message_error_passed_user(self, chat_interface):
with pytest.raises(ValueError, match="Cannot set user or avatar"):
chat_interface.stream(ChatMessage(
"testeroo", user="useroo", avatar="avataroo",
), user="newuser")

def test_stream_chat_message_error_passed_avatar(self, chat_interface):
with pytest.raises(ValueError, match="Cannot set user or avatar"):
chat_interface.stream(ChatMessage(
"testeroo", user="useroo", avatar="avataroo",
), avatar="newavatar")

class TestChatInterfaceWidgetsSizingMode:
def test_none(self):
chat_interface = ChatInterface()
Expand Down
Loading