Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamlitCallbackHandler #6315

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions langchain/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
from langchain.callbacks.streaming_stdout_final_only import (
FinalStreamingStdOutCallbackHandler,
)

# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
# from langchain.callbacks.streamlit import StreamlitCallbackHandler
from langchain.callbacks.streamlit import LLMThoughtLabeler, StreamlitCallbackHandler
from langchain.callbacks.wandb_callback import WandbCallbackHandler
from langchain.callbacks.whylabs_callback import WhyLabsCallbackHandler

Expand All @@ -42,8 +40,8 @@
"OpenAICallbackHandler",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
# now streamlit requires Python >=3.7, !=3.9.7 So, it is commented out here.
# "StreamlitCallbackHandler",
"StreamlitCallbackHandler",
"LLMThoughtLabeler",
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
Expand Down
104 changes: 0 additions & 104 deletions langchain/callbacks/streamlit.py

This file was deleted.

79 changes: 79 additions & 0 deletions langchain/callbacks/streamlit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import (
LLMThoughtLabeler as LLMThoughtLabeler,
)
from langchain.callbacks.streamlit.streamlit_callback_handler import (
StreamlitCallbackHandler as _InternalStreamlitCallbackHandler,
)

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator


def StreamlitCallbackHandler(
parent_container: DeltaGenerator,
*,
max_thought_containers: int = 4,
expand_new_thoughts: bool = True,
collapse_completed_thoughts: bool = True,
thought_labeler: Optional[LLMThoughtLabeler] = None,
) -> BaseCallbackHandler:
"""Construct a new StreamlitCallbackHandler. This CallbackHandler is geared towards
use with a LangChain Agent; it displays the Agent's LLM and tool-usage "thoughts"
inside a series of Streamlit expanders.

Parameters
----------
parent_container
The `st.container` that will contain all the Streamlit elements that the
Handler creates.
max_thought_containers
The max number of completed LLM thought containers to show at once. When this
threshold is reached, a new thought will cause the oldest thoughts to be
collapsed into a "History" expander. Defaults to 4.
expand_new_thoughts
Each LLM "thought" gets its own `st.expander`. This param controls whether that
expander is expanded by default. Defaults to True.
collapse_completed_thoughts
If True, LLM thought expanders will be collapsed when completed.
Defaults to True.
thought_labeler
An optional custom LLMThoughtLabeler instance. If unspecified, the handler
will use the default thought labeling logic. Defaults to None.

Returns
-------
A new StreamlitCallbackHandler instance.

Note that this is an "auto-updating" API: if the installed version of Streamlit
has a more recent StreamlitCallbackHandler implementation, an instance of that class
will be used.

"""
# If we're using a version of Streamlit that implements StreamlitCallbackHandler,
# delegate to it instead of using our built-in handler. The official handler is
# guaranteed to support the same set of kwargs.
try:
from streamlit.external.langchain import (
StreamlitCallbackHandler as OfficialStreamlitCallbackHandler, # type: ignore # noqa: 501
)

return OfficialStreamlitCallbackHandler(
parent_container,
max_thought_containers=max_thought_containers,
expand_new_thoughts=expand_new_thoughts,
collapse_completed_thoughts=collapse_completed_thoughts,
thought_labeler=thought_labeler,
)
except ImportError:
return _InternalStreamlitCallbackHandler(
parent_container,
max_thought_containers=max_thought_containers,
expand_new_thoughts=expand_new_thoughts,
collapse_completed_thoughts=collapse_completed_thoughts,
thought_labeler=thought_labeler,
)
152 changes: 152 additions & 0 deletions langchain/callbacks/streamlit/mutable_expander.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional

if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
from streamlit.type_util import SupportsStr


class ChildType(Enum):
MARKDOWN = "MARKDOWN"
EXCEPTION = "EXCEPTION"


class ChildRecord(NamedTuple):
type: ChildType
kwargs: Dict[str, Any]
dg: DeltaGenerator


class MutableExpander:
"""A Streamlit expander that can be renamed and dynamically expanded/collapsed."""

def __init__(self, parent_container: DeltaGenerator, label: str, expanded: bool):
"""Create a new MutableExpander.

Parameters
----------
parent_container
The `st.container` that the expander will be created inside.

The expander transparently deletes and recreates its underlying
`st.expander` instance when its label changes, and it uses
`parent_container` to ensure it recreates this underlying expander in the
same location onscreen.
label
The expander's initial label.
expanded
The expander's initial `expanded` value.
"""
self._label = label
self._expanded = expanded
self._parent_cursor = parent_container.empty()
self._container = self._parent_cursor.expander(label, expanded)
self._child_records: List[ChildRecord] = []

@property
def label(self) -> str:
"""The expander's label string."""
return self._label

@property
def expanded(self) -> bool:
"""True if the expander was created with `expanded=True`."""
return self._expanded

def clear(self) -> None:
"""Remove the container and its contents entirely. A cleared container can't
be reused.
"""
self._container = self._parent_cursor.empty()
self._child_records.clear()

def append_copy(self, other: MutableExpander) -> None:
"""Append a copy of another MutableExpander's children to this
MutableExpander.
"""
other_records = other._child_records.copy()
for record in other_records:
self._create_child(record.type, record.kwargs)

def update(
self, *, new_label: Optional[str] = None, new_expanded: Optional[bool] = None
) -> None:
"""Change the expander's label and expanded state"""
if new_label is None:
new_label = self._label
if new_expanded is None:
new_expanded = self._expanded

if self._label == new_label and self._expanded == new_expanded:
# No change!
return

self._label = new_label
self._expanded = new_expanded
self._container = self._parent_cursor.expander(new_label, new_expanded)

prev_records = self._child_records
self._child_records = []

# Replay all children into the new container
for record in prev_records:
self._create_child(record.type, record.kwargs)

def markdown(
self,
body: SupportsStr,
unsafe_allow_html: bool = False,
*,
help: Optional[str] = None,
index: Optional[int] = None,
) -> int:
"""Add a Markdown element to the container and return its index."""
kwargs = {"body": body, "unsafe_allow_html": unsafe_allow_html, "help": help}
new_dg = self._get_dg(index).markdown(**kwargs) # type: ignore[arg-type]
record = ChildRecord(ChildType.MARKDOWN, kwargs, new_dg)
return self._add_record(record, index)

def exception(
self, exception: BaseException, *, index: Optional[int] = None
) -> int:
"""Add an Exception element to the container and return its index."""
kwargs = {"exception": exception}
new_dg = self._get_dg(index).exception(**kwargs)
record = ChildRecord(ChildType.EXCEPTION, kwargs, new_dg)
return self._add_record(record, index)

def _create_child(self, type: ChildType, kwargs: Dict[str, Any]) -> None:
"""Create a new child with the given params"""
if type == ChildType.MARKDOWN:
self.markdown(**kwargs)
elif type == ChildType.EXCEPTION:
self.exception(**kwargs)
else:
raise RuntimeError(f"Unexpected child type {type}")

def _add_record(self, record: ChildRecord, index: Optional[int]) -> int:
"""Add a ChildRecord to self._children. If `index` is specified, replace
the existing record at that index. Otherwise, append the record to the
end of the list.

Return the index of the added record.
"""
if index is not None:
# Replace existing child
self._child_records[index] = record
return index

# Append new child
self._child_records.append(record)
return len(self._child_records) - 1

def _get_dg(self, index: Optional[int]) -> DeltaGenerator:
if index is not None:
# Existing index: reuse child's DeltaGenerator
assert 0 <= index < len(self._child_records), f"Bad index: {index}"
return self._child_records[index].dg

# No index: use container's DeltaGenerator
return self._container
Loading