Skip to content

Commit

Permalink
add standardized stream handling
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Sep 2, 2024
1 parent 9d48740 commit 9961c1d
Showing 1 changed file with 37 additions and 28 deletions.
65 changes: 37 additions & 28 deletions backend/danswer/llm/answering/answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
Expand Down Expand Up @@ -311,19 +310,31 @@ def _raw_output_for_explicit_tool_calling_llms(
yield tool_runner.tool_final_result()

prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
for token in message_generator_to_string_generator(
self.llm.stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
):
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)

yield cast(str, token)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)

return

def _process_llm_stream(
    self,
    prompt: Any,
    tools: list[dict] | None = None,
) -> Iterator[str | StreamStopInfo]:
    """Stream tokens from the LLM, emitting stop info when streaming ends early.

    Yields each non-empty content chunk as a string. If the request is
    cancelled mid-stream, yields a CANCELLED StreamStopInfo and stops; if the
    provider reports it stopped for hitting the output-length limit, yields a
    CONTEXT_LENGTH StreamStopInfo.

    Args:
        prompt: the built prompt passed through to ``self.llm.stream``
            (presumably a message list — confirm against LLM.stream).
        tools: optional tool definitions for tool-calling LLMs; ``None``
            (the default) means no tools, matching call sites that only
            stream a plain completion.
    """
    for message in self.llm.stream(
        prompt=prompt,
        tools=tools,
    ):
        if isinstance(message, AIMessageChunk):
            if message.content:
                if self.is_cancelled:
                    # Must YIELD the stop info: a `return value` inside a
                    # generator is carried on StopIteration and silently
                    # discarded by the callers' `yield from`, so consumers
                    # would never see the cancellation.
                    yield StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                    return
                yield cast(str, message.content)

        # Provider signalled it truncated the answer at its length limit.
        if (
            message.additional_kwargs.get("usage_metadata", {}).get("stop")
            == "length"
        ):
            yield StreamStopInfo(stop_reason=StreamStopReason.CONTEXT_LENGTH)

def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
Expand Down Expand Up @@ -401,14 +412,10 @@ def _raw_output_for_non_explicit_tool_calling_llms(
)
)
prompt = prompt_builder.build()
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield token

return
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)

tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
Expand Down Expand Up @@ -461,12 +468,8 @@ def _raw_output_for_non_explicit_tool_calling_llms(
yield final

prompt = prompt_builder.build()
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return
yield token

yield from self._process_llm_stream(self, prompt=prompt)

@property
def processed_streamed_output(self) -> AnswerStream:
Expand Down Expand Up @@ -539,11 +542,17 @@ def _process_stream(
)

def _stream() -> Iterator[str | StreamStopInfo]:
    """Yield the pre-computed message (if any), then every item of the stream.

    NOTE(review): the previous form
    ``yield from (message if message else (item for item in stream))``
    was wrong on both branches of intent: when ``message`` was truthy it
    iterated the *string itself*, yielding it character-by-character, and
    never consumed ``stream`` at all. The message must be yielded once as a
    single item, followed by the full stream.
    """
    if message:
        yield cast(str, message)
    yield from stream

yield from process_answer_stream_fn(_stream())
for item in _stream():
if isinstance(item, StreamStopInfo):
yield item
else:
yield from process_answer_stream_fn(iter([item]))

processed_stream = []
for processed_packet in _process_stream(output_generator):
Expand Down

0 comments on commit 9961c1d

Please sign in to comment.