forked from streamlit/streamlit
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LangChain CallbackHandler (streamlit#6890)
Add our LangChain `StreamlitCallbackHandler` (also present in the [LangChain repo](langchain-ai/langchain#6315)), along with some Streamlit-specific tests. When used from LangChain, this callback handler is an "auto-updating API". That is, a LangChain user can do ```python from langchain.callbacks.streamlit import StreamlitCallbackHandler callback = StreamlitCallbackHandler(st.container()) ``` and if they have a recent version of Streamlit installed in their environment, Streamlit's copy of the callback handler will be used instead of the LangChain-internal one. This allows us to update and improve `StreamlitCallbackHandler` independently of LangChain, and LangChain users of the callback will see those changes automatically. In other words, while `StreamlitCallbackHandler` is not part of the public Streamlit `st` API, it _is_ part of LangChain's public API, and we need to keep it stable. (This PR contains a few tests that assert its API stability.)
- Loading branch information
Showing
12 changed files
with
992 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from streamlit.external.langchain.streamlit_callback_handler import ( | ||
LLMThoughtLabeler as LLMThoughtLabeler, | ||
) | ||
from streamlit.external.langchain.streamlit_callback_handler import ( | ||
StreamlitCallbackHandler as StreamlitCallbackHandler, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022) | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from enum import Enum | ||
from typing import Any, NamedTuple, Optional | ||
|
||
from streamlit.delta_generator import DeltaGenerator | ||
from streamlit.type_util import SupportsStr | ||
|
||
|
||
class ChildType(Enum): | ||
MARKDOWN = "MARKDOWN" | ||
EXCEPTION = "EXCEPTION" | ||
|
||
|
||
class ChildRecord(NamedTuple): | ||
type: ChildType | ||
kwargs: dict[str, Any] | ||
dg: DeltaGenerator | ||
|
||
|
||
class MutableExpander: | ||
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed. | ||
Used internally by StreamlitCallbackHandler. | ||
NB: this class's functionality is tested only indirectly by `streamlit_callback_handler_test.py`. | ||
It's not currently intended for use outside of StreamlitCallbackHandler. If it | ||
becomes more broadly useful, we should consider turning it into a "proper" Streamlit | ||
that can support mutations without rebuilding its entire state. | ||
""" | ||
|
||
def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool): | ||
"""Create a new MutableExpander. | ||
Parameters | ||
---------- | ||
parent_container | ||
The `st.container` that the expander will be created inside. | ||
The expander transparently deletes and recreates its underlying | ||
`st.expander` instance when its label changes, and it uses | ||
`parent_container` to ensure it recreates this underlying expander in the | ||
same location onscreen. | ||
label | ||
The expander's initial label. | ||
expanded | ||
The expander's initial `expanded` value. | ||
""" | ||
self._label = label | ||
self._expanded = expanded | ||
self._parent_cursor = parent_container.empty() | ||
self._container = self._parent_cursor.expander(label, expanded) | ||
self._child_records: list[ChildRecord] = [] | ||
|
||
@property | ||
def label(self) -> str: | ||
"""The expander's label string.""" | ||
return self._label | ||
|
||
@property | ||
def expanded(self) -> bool: | ||
"""True if the expander was created with `expanded=True`.""" | ||
return self._expanded | ||
|
||
def clear(self) -> None: | ||
"""Remove the container and its contents entirely. A cleared container can't | ||
be reused. | ||
""" | ||
self._container = self._parent_cursor.empty() | ||
self._child_records.clear() | ||
|
||
def append_copy(self, other: MutableExpander) -> None: | ||
"""Append a copy of another MutableExpander's children to this | ||
MutableExpander. | ||
""" | ||
other_records = other._child_records.copy() | ||
for record in other_records: | ||
self._create_child(record.type, record.kwargs) | ||
|
||
def update( | ||
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None | ||
) -> None: | ||
"""Change the expander's label and expanded state""" | ||
if new_label is None: | ||
new_label = self._label | ||
if new_expanded is None: | ||
new_expanded = self._expanded | ||
|
||
if self._label == new_label and self._expanded == new_expanded: | ||
# No change! | ||
return | ||
|
||
self._label = new_label | ||
self._expanded = new_expanded | ||
self._container = self._parent_cursor.expander(new_label, new_expanded) | ||
|
||
prev_records = self._child_records | ||
self._child_records = [] | ||
|
||
# Replay all children into the new container | ||
for record in prev_records: | ||
self._create_child(record.type, record.kwargs) | ||
|
||
def markdown( | ||
self, | ||
body: SupportsStr, | ||
unsafe_allow_html: bool = False, | ||
*, | ||
help: Optional[str] = None, | ||
index: Optional[int] = None, | ||
) -> int: | ||
"""Add a Markdown element to the container and return its index.""" | ||
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help} | ||
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type] | ||
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg) | ||
return self._add_record(record, index) | ||
|
||
def exception( | ||
self, exception: BaseException, *, index: Optional[int] = None | ||
) -> int: | ||
"""Add an Exception element to the container and return its index.""" | ||
kwargs = {"exception": exception} | ||
new_dg = self._get_dg(index).exception(**kwargs) | ||
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg) | ||
return self._add_record(record, index) | ||
|
||
def _create_child(self, type: ChildType, kwargs: dict[str, Any]) -> None: | ||
"""Create a new child with the given params""" | ||
if type == ChildType.MARKDOWN: | ||
self.markdown(**kwargs) | ||
elif type == ChildType.EXCEPTION: | ||
self.exception(**kwargs) | ||
else: | ||
raise RuntimeError(f"Unexpected child type {type}") | ||
|
||
def _add_record(self, record: ChildRecord, index: Optional[int]) -> int: | ||
"""Add a ChildRecord to self._children. If `index` is specified, replace | ||
the existing record at that index. Otherwise, append the record to the | ||
end of the list. | ||
Return the index of the added record. | ||
""" | ||
if index is not None: | ||
# Replace existing child | ||
self._child_records[index] = record | ||
return index | ||
|
||
# Append new child | ||
self._child_records.append(record) | ||
return len(self._child_records) - 1 | ||
|
||
def _get_dg(self, index: Optional[int]) -> DeltaGenerator: | ||
if index is not None: | ||
# Existing index: reuse child's DeltaGenerator | ||
assert 0 <= index < len(self._child_records), f"Bad index: {index}" | ||
return self._child_records[index].dg | ||
|
||
# No index: use container's DeltaGenerator | ||
return self._container |
Oops, something went wrong.