Skip to content

Commit

Permalink
LangChain CallbackHandler (streamlit#6890)
Browse files Browse the repository at this point in the history
Add our LangChain `StreamlitCallbackHandler` (also present in the [LangChain repo](langchain-ai/langchain#6315)), along with some Streamlit-specific tests.

When used from LangChain, this callback handler is an "auto-updating API". That is, a LangChain user can do

```python
from langchain.callbacks.streamlit import StreamlitCallbackHandler
callback = StreamlitCallbackHandler(st.container())
```

and if they have a recent version of Streamlit installed in their environment, Streamlit's copy of the callback handler will be used instead of the LangChain-internal one. This allows us to update and improve `StreamlitCallbackHandler` independently of LangChain, and LangChain users of the callback will see those changes automatically.

In other words, while `StreamlitCallbackHandler` is not part of the public Streamlit `st` API, it _is_ part of LangChain's public API, and we need to keep it stable. (This PR contains a few tests that assert its API stability.)
  • Loading branch information
tconkling authored and Your Name committed Mar 22, 2024
1 parent 2bf26d5 commit f508009
Show file tree
Hide file tree
Showing 12 changed files with 992 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/min-constraints-gen.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pympler==0.9
python-dateutil==2.7.3
requests==2.18
rich==10.14.0
tenacity==8.0.0
tenacity==8.1.0
toml==0.10.1
tornado==6.0.3
typing-extensions==4.1.0
Expand Down
2 changes: 1 addition & 1 deletion lib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
"python-dateutil>=2.7.3, <3",
"requests>=2.18, <3",
"rich>=10.14.0, <14",
"tenacity>=8.0.0, <9",
"tenacity>=8.1.0, <9",
"toml>=0.10.1, <2",
"typing-extensions>=4.1.0, <5",
"tzlocal>=1.1, <5",
Expand Down
13 changes: 13 additions & 0 deletions lib/streamlit/external/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
20 changes: 20 additions & 0 deletions lib/streamlit/external/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from streamlit.external.langchain.streamlit_callback_handler import (
LLMThoughtLabeler as LLMThoughtLabeler,
)
from streamlit.external.langchain.streamlit_callback_handler import (
StreamlitCallbackHandler as StreamlitCallbackHandler,
)
172 changes: 172 additions & 0 deletions lib/streamlit/external/langchain/mutable_expander.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from enum import Enum
from typing import Any, NamedTuple, Optional

from streamlit.delta_generator import DeltaGenerator
from streamlit.type_util import SupportsStr


class ChildType(Enum):
MARKDOWN = "MARKDOWN"
EXCEPTION = "EXCEPTION"


class ChildRecord(NamedTuple):
type: ChildType
kwargs: dict[str, Any]
dg: DeltaGenerator


class MutableExpander:
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed.
Used internally by StreamlitCallbackHandler.
NB: this class's functionality is tested only indirectly by `streamlit_callback_handler_test.py`.
It's not currently intended for use outside of StreamlitCallbackHandler. If it
becomes more broadly useful, we should consider turning it into a "proper" Streamlit
that can support mutations without rebuilding its entire state.
"""

def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
"""Create a new MutableExpander.
Parameters
----------
parent_container
The `st.container` that the expander will be created inside.
The expander transparently deletes and recreates its underlying
`st.expander` instance when its label changes, and it uses
`parent_container` to ensure it recreates this underlying expander in the
same location onscreen.
label
The expander's initial label.
expanded
The expander's initial `expanded` value.
"""
self._label = label
self._expanded = expanded
self._parent_cursor = parent_container.empty()
self._container = self._parent_cursor.expander(label, expanded)
self._child_records: list[ChildRecord] = []

@property
def label(self) -> str:
"""The expander's label string."""
return self._label

@property
def expanded(self) -> bool:
"""True if the expander was created with `expanded=True`."""
return self._expanded

def clear(self) -> None:
"""Remove the container and its contents entirely. A cleared container can't
be reused.
"""
self._container = self._parent_cursor.empty()
self._child_records.clear()

def append_copy(self, other: MutableExpander) -> None:
"""Append a copy of another MutableExpander's children to this
MutableExpander.
"""
other_records = other._child_records.copy()
for record in other_records:
self._create_child(record.type, record.kwargs)

def update(
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
) -> None:
"""Change the expander's label and expanded state"""
if new_label is None:
new_label = self._label
if new_expanded is None:
new_expanded = self._expanded

if self._label == new_label and self._expanded == new_expanded:
# No change!
return

self._label = new_label
self._expanded = new_expanded
self._container = self._parent_cursor.expander(new_label, new_expanded)

prev_records = self._child_records
self._child_records = []

# Replay all children into the new container
for record in prev_records:
self._create_child(record.type, record.kwargs)

def markdown(
self,
body: SupportsStr,
unsafe_allow_html: bool = False,
*,
help: Optional[str] = None,
index: Optional[int] = None,
) -> int:
"""Add a Markdown element to the container and return its index."""
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
return self._add_record(record, index)

def exception(
self, exception: BaseException, *, index: Optional[int] = None
) -> int:
"""Add an Exception element to the container and return its index."""
kwargs = {"exception": exception}
new_dg = self._get_dg(index).exception(**kwargs)
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
return self._add_record(record, index)

def _create_child(self, type: ChildType, kwargs: dict[str, Any]) -> None:
"""Create a new child with the given params"""
if type == ChildType.MARKDOWN:
self.markdown(**kwargs)
elif type == ChildType.EXCEPTION:
self.exception(**kwargs)
else:
raise RuntimeError(f"Unexpected child type {type}")

def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
"""Add a ChildRecord to self._children. If `index` is specified, replace
the existing record at that index. Otherwise, append the record to the
end of the list.
Return the index of the added record.
"""
if index is not None:
# Replace existing child
self._child_records[index] = record
return index

# Append new child
self._child_records.append(record)
return len(self._child_records) - 1

def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
if index is not None:
# Existing index: reuse child's DeltaGenerator
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
return self._child_records[index].dg

# No index: use container's DeltaGenerator
return self._container
Loading

0 comments on commit f508009

Please sign in to comment.