diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 7e7733cc0fe..0a0a1c52afa 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -550,15 +550,21 @@ def _process_stream( answer_style_configs=self.answer_style_config, ) - def _stream() -> Iterator[str | StreamStopInfo]: - yield cast(str | StreamStopInfo, message) - yield from (cast(str | StreamStopInfo, item) for item in stream) + stream_stop_info = None + + def _stream() -> Iterator[str]: + nonlocal stream_stop_info + yield cast(str, message) + for item in stream: + if isinstance(item, StreamStopInfo): + stream_stop_info = item + return + yield cast(str, item) - for item in _stream(): - if isinstance(item, StreamStopInfo): - yield item - else: - yield from process_answer_stream_fn(iter([item])) + yield from process_answer_stream_fn(_stream()) + + if stream_stop_info: + yield stream_stop_info processed_stream = [] for processed_packet in _process_stream(output_generator):