diff --git a/pyproject.toml b/pyproject.toml
index 20d485b4..e4ffa2c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,10 +22,10 @@ classifiers = [
 requires-python = ">=3.9"
 dependencies = [
     "aiofiles",
-    "anyio",
     "emoji",
     "fastapi",
     "httpx",
+    "httpx_sse",
    "importlib_metadata>=4.6; python_version<'3.10'",
     "packaging",
     "panel>=1.3.6,<1.4",
@@ -38,6 +38,8 @@ dependencies = [
     "questionary",
     "rich",
     "sqlalchemy>=2",
+    "sse-starlette",
+    "starlette",
     "tomlkit",
     "typer",
     "uvicorn",
diff --git a/ragna/_compat.py b/ragna/_compat.py
index 58b1d486..f87da373 100644
--- a/ragna/_compat.py
+++ b/ragna/_compat.py
@@ -1,7 +1,16 @@
+import builtins
 import sys
-from typing import Callable, Iterable, Iterator, Mapping, TypeVar
-
-__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"]
+from typing import (
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Iterable,
+    Iterator,
+    Mapping,
+    TypeVar,
+)
+
+__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions", "anext"]
 
 T = TypeVar("T")
 
@@ -38,3 +47,17 @@ def _importlib_metadata_package_distributions() -> (
 
 
 importlib_metadata_package_distributions = _importlib_metadata_package_distributions()
+
+
+def _anext() -> Callable[[AsyncIterator[T]], Awaitable[T]]:
+    if sys.version_info[:2] >= (3, 10):
+        anext = builtins.anext
+    else:
+
+        async def anext(ait: AsyncIterator[T]) -> T:
+            return await ait.__anext__()
+
+    return anext
+
+
+anext = _anext()
diff --git a/ragna/assistants/_anthropic.py b/ragna/assistants/_anthropic.py
index 0c7b97ce..652bab7f 100644
--- a/ragna/assistants/_anthropic.py
+++ b/ragna/assistants/_anthropic.py
@@ -1,4 +1,7 @@
-from typing import cast
+import json
+from typing import AsyncIterator, cast
+
+import httpx_sse
 
 from ragna.core import RagnaException, Source
 
@@ -30,9 +33,11 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
 
     async def _call_api(
         self, prompt: str, sources: list[Source], *, max_new_tokens: int
-    ) -> str:
-        # See https://docs.anthropic.com/claude/reference/complete_post
-        response = await self._client.post(
+    ) -> AsyncIterator[str]:
+        # See https://docs.anthropic.com/claude/reference/streaming
+        async with httpx_sse.aconnect_sse(
+            self._client,
+            "POST",
             "https://api.anthropic.com/v1/complete",
             headers={
                 "accept": "application/json",
@@ -45,13 +50,19 @@ async def _call_api(
                 "prompt": self._instructize_prompt(prompt, sources),
                 "max_tokens_to_sample": max_new_tokens,
                 "temperature": 0.0,
+                "stream": True,
             },
-        )
-        if response.is_error:
-            raise RagnaException(
-                status_code=response.status_code, response=response.json()
-            )
-        return cast(str, response.json()["completion"])
+        ) as event_source:
+            async for sse in event_source.aiter_sse():
+                data = json.loads(sse.data)
+                if data["type"] != "completion":
+                    continue
+                elif "error" in data:
+                    raise RagnaException(data["error"].pop("message"), **data["error"])
+                elif data["stop_reason"] is not None:
+                    break
+
+                yield cast(str, data["completion"])
 
 
 class ClaudeInstant(AnthropicApiAssistant):
diff --git a/ragna/assistants/_api.py b/ragna/assistants/_api.py
index 21be7557..cc03650a 100644
--- a/ragna/assistants/_api.py
+++ b/ragna/assistants/_api.py
@@ -1,5 +1,8 @@
 import abc
 import os
+from typing import AsyncIterator
+
+import httpx
 
 import ragna
 from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
@@ -13,8 +16,6 @@ def requirements(cls) -> list[Requirement]:
         return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
 
     def __init__(self) -> None:
-        import httpx
-
         self._client = httpx.AsyncClient(
headers={"User-Agent": f"{ragna.__version__}/{self}"}, timeout=60, @@ -23,11 +24,14 @@ def __init__(self) -> None: async def answer( self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256 - ) -> str: - return await self._call_api(prompt, sources, max_new_tokens=max_new_tokens) + ) -> AsyncIterator[str]: + async for chunk in self._call_api( # type: ignore[attr-defined, misc] + prompt, sources, max_new_tokens=max_new_tokens + ): + yield chunk @abc.abstractmethod async def _call_api( self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> str: + ) -> AsyncIterator[str]: ... diff --git a/ragna/assistants/_demo.py b/ragna/assistants/_demo.py index 74543c3f..b19ed533 100644 --- a/ragna/assistants/_demo.py +++ b/ragna/assistants/_demo.py @@ -1,6 +1,7 @@ import re import sys import textwrap +from typing import Iterator from ragna.core import Assistant, Source @@ -26,11 +27,11 @@ def display_name(cls) -> str: def max_input_size(self) -> int: return sys.maxsize - def answer(self, prompt: str, sources: list[Source]) -> str: + def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]: if re.search("markdown", prompt, re.IGNORECASE): - return self._markdown_answer() + yield self._markdown_answer() else: - return self._default_answer(prompt, sources) + yield self._default_answer(prompt, sources) def _markdown_answer(self) -> str: return textwrap.dedent( diff --git a/ragna/assistants/_mosaicml.py b/ragna/assistants/_mosaicml.py index 655b5e13..52341198 100644 --- a/ragna/assistants/_mosaicml.py +++ b/ragna/assistants/_mosaicml.py @@ -1,4 +1,4 @@ -from typing import cast +from typing import AsyncIterator, cast from ragna.core import RagnaException, Source @@ -29,7 +29,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str: async def _call_api( self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> str: + ) -> AsyncIterator[str]: instruction = self._instructize_prompt(prompt, sources) # https://docs.mosaicml.com/en/latest/inference.html#text-completion-requests response = await self._client.post( @@ -47,7 +47,7 @@ async def _call_api( raise RagnaException( status_code=response.status_code, response=response.json() ) - return cast(str, response.json()["outputs"][0]).replace(instruction, "").strip() + yield cast(str, response.json()["outputs"][0]).replace(instruction, "").strip() class Mpt7bInstruct(MosaicmlApiAssistant): diff --git a/ragna/assistants/_openai.py b/ragna/assistants/_openai.py index 84a79ef7..cfd4d5b6 100644 --- a/ragna/assistants/_openai.py +++ b/ragna/assistants/_openai.py @@ -1,6 +1,9 @@ -from typing import cast +import json +from typing import AsyncIterator, cast -from ragna.core import RagnaException, Source +import httpx_sse + +from ragna.core import Source from ._api import ApiAssistant @@ -29,10 +32,12 @@ def _make_system_content(self, sources: list[Source]) -> str: async def _call_api( self, prompt: str, sources: list[Source], *, max_new_tokens: int - ) -> str: + ) -> AsyncIterator[str]: # See https://platform.openai.com/docs/api-reference/chat/create - # and https://platform.openai.com/docs/api-reference/chat/object - response = await self._client.post( + # and https://platform.openai.com/docs/api-reference/chat/streaming + async with httpx_sse.aconnect_sse( + self._client, + "POST", "https://api.openai.com/v1/chat/completions", headers={ "Content-Type": "application/json", @@ -52,13 +57,16 @@ async def _call_api( "model": self._MODEL, "temperature": 0.0, "max_tokens": max_new_tokens, + "stream": 
             },
-        )
-        if response.is_error:
-            raise RagnaException(
-                status_code=response.status_code, response=response.json()
-            )
-        return cast(str, response.json()["choices"][0]["message"]["content"])
+        ) as event_source:
+            async for sse in event_source.aiter_sse():
+                data = json.loads(sse.data)
+                choice = data["choices"][0]
+                if choice["finish_reason"] is not None:
+                    break
+
+                yield cast(str, choice["delta"]["content"])
 
 
 class Gpt35Turbo16k(OpenaiApiAssistant):
diff --git a/ragna/core/_components.py b/ragna/core/_components.py
index f2196439..5dad8222 100644
--- a/ragna/core/_components.py
+++ b/ragna/core/_components.py
@@ -4,7 +4,7 @@
 import enum
 import functools
 import inspect
-from typing import Type
+from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union
 
 import pydantic
 import pydantic.utils
@@ -138,11 +138,10 @@ class MessageRole(enum.Enum):
     ASSISTANT = "assistant"
 
 
-class Message(pydantic.BaseModel):
+class Message:
     """Data class for messages.
 
     Attributes:
-        content: The content of the message.
         role: The message producer.
         sources: The sources used to produce the message.
 
@@ -152,13 +151,61 @@ class Message(pydantic.BaseModel):
     - [ragna.core.Chat.answer][]
     """
 
-    content: str
-    role: MessageRole
-    sources: list[Source] = pydantic.Field(default_factory=list)
+    def __init__(
+        self,
+        content: Union[str, AsyncIterable[str]],
+        *,
+        role: MessageRole = MessageRole.SYSTEM,
+        sources: Optional[list[Source]] = None,
+    ) -> None:
+        if isinstance(content, str):
+            self._content: str = content
+        else:
+            self._content_stream: AsyncIterable[str] = content
+
+        self.role = role
+        self.sources = sources or []
+
+    async def __aiter__(self) -> AsyncIterator[str]:
+        if hasattr(self, "_content"):
+            yield self._content
+            return
+
+        chunks = []
+        async for chunk in self._content_stream:
+            chunks.append(chunk)
+            yield chunk
+
+        self._content = "".join(chunks)
+
+    async def read(self) -> str:
+        if not hasattr(self, "_content"):
+            # Since self.__aiter__ is already setting the self._content attribute, we
+            # only need to exhaust the content stream here.
+            async for _ in self:
+                pass
+        return self._content
+
+    @property
+    def content(self) -> str:
+        if not hasattr(self, "_content"):
+            raise RuntimeError(
+                "Message content cannot be accessed without having iterated over it, "
+                "e.g. `async for chunk in message`, or reading the content, e.g. "
+                "`await message.read()`, first."
+            )
+        return self._content
 
     def __str__(self) -> str:
         return self.content
 
+    def __repr__(self) -> str:
+        return (
+            f"{type(self).__name__}("
+            f"content={self.content}, role={self.role}, sources={self.sources}"
+            f")"
+        )
+
 
 class Assistant(Component, abc.ABC):
     """Abstract base class for assistants used in [ragna.core.Chat][]"""
@@ -171,7 +218,7 @@ def max_input_size(self) -> int:
         ...
 
     @abc.abstractmethod
-    def answer(self, prompt: str, sources: list[Source]) -> str:
+    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
         """Answer a prompt given some sources.
 
         Args:
diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py
index b1af0739..35c4c8e5 100644
--- a/ragna/core/_rag.py
+++ b/ragna/core/_rag.py
@@ -1,15 +1,16 @@
 from __future__ import annotations
 
 import datetime
-import functools
 import inspect
 import uuid
 from typing import (
     Any,
+    AsyncIterator,
     Awaitable,
     Callable,
     Generic,
     Iterable,
+    Iterator,
     Optional,
     Type,
     TypeVar,
@@ -17,8 +18,8 @@
     cast,
 )
 
-import anyio
 import pydantic
+from starlette.concurrency import iterate_in_threadpool, run_in_threadpool
 
 from ._components import Assistant, Component, Message, MessageRole, SourceStorage
 from ._document import Document, LocalDocument
@@ -181,7 +182,7 @@ async def prepare(self) -> Message:
         self._messages.append(welcome)
         return welcome
 
-    async def answer(self, prompt: str) -> Message:
+    async def answer(self, prompt: str, *, stream: bool = False) -> Message:
         """Answer a prompt.
 
         Returns:
@@ -199,25 +200,19 @@ async def answer(self, prompt: str) -> Message:
             detail=RagnaException.EVENT,
         )
 
-        prompt = Message(content=prompt, role=MessageRole.USER)
-        self._messages.append(prompt)
+        self._messages.append(Message(content=prompt, role=MessageRole.USER))
+
+        sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
 
-        sources = await self._run(
-            self.source_storage.retrieve, self.documents, prompt.content
-        )
         answer = Message(
-            content=await self._run(self.assistant.answer, prompt.content, sources),
+            content=self._run_gen(self.assistant.answer, prompt, sources),
             role=MessageRole.ASSISTANT,
             sources=sources,
         )
-        self._messages.append(answer)
+        if not stream:
+            await answer.read()
 
-        # FIXME: add error handling
-        # return (
-        #     "I'm sorry, but I'm having trouble helping you at this time. "
-        #     "Please retry later. "
-        #     "If this issue persists, please contact your administrator."
-        # )
+        self._messages.append(answer)
 
         return answer
 
@@ -261,16 +256,34 @@ def _unpack_chat_params(
             for fn, model in component_models.items()
         }
 
-    async def _run(self, fn: Callable[..., Union[T, Awaitable[T]]], *args: Any) -> T:
+    async def _run(
+        self, fn: Union[Callable[..., T], Callable[..., Awaitable[T]]], *args: Any
+    ) -> T:
         kwargs = self._unpacked_params[fn]
         if inspect.iscoroutinefunction(fn):
             fn = cast(Callable[..., Awaitable[T]], fn)
-            return await fn(*args, **kwargs)
+            coro = fn(*args, **kwargs)
         else:
             fn = cast(Callable[..., T], fn)
-            return await anyio.to_thread.run_sync(
-                functools.partial(fn, *args, **kwargs)
-            )
+            coro = run_in_threadpool(fn, *args, **kwargs)
+
+        return await coro
+
+    async def _run_gen(
+        self,
+        fn: Union[Callable[..., Iterator[T]], Callable[..., AsyncIterator[T]]],
+        *args: Any,
+    ) -> AsyncIterator[T]:
+        kwargs = self._unpacked_params[fn]
+        if inspect.isasyncgenfunction(fn):
+            fn = cast(Callable[..., AsyncIterator[T]], fn)
+            async_gen = fn(*args, **kwargs)
+        else:
+            fn = cast(Callable[..., Iterator[T]], fn)
+            async_gen = iterate_in_threadpool(fn(*args, **kwargs))
+
+        async for item in async_gen:
+            yield item
 
     async def __aenter__(self) -> Chat:
         await self.prepare()
diff --git a/ragna/deploy/_api/core.py b/ragna/deploy/_api/core.py
index 6dfa7bf7..6e160548 100644
--- a/ragna/deploy/_api/core.py
+++ b/ragna/deploy/_api/core.py
@@ -1,9 +1,10 @@
 import contextlib
 import itertools
 import uuid
-from typing import Annotated, Any, Iterator, Type, cast
+from typing import Annotated, Any, AsyncIterator, Iterator, Type, cast
 
 import aiofiles
+import sse_starlette
 from fastapi import Body, Depends, FastAPI, Form, HTTPException, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
@@ -224,20 +225,51 @@ async def prepare_chat(user: UserDependency, id: uuid.UUID) -> schemas.Message:
 
 @app.post("/chats/{id}/answer")
 async def answer(
-    user: UserDependency, id: uuid.UUID, prompt: str
+    user: UserDependency,
+    id: uuid.UUID,
+    prompt: Annotated[str, Body(..., embed=True)],
+    stream: Annotated[bool, Body(..., embed=True)] = False,
 ) -> schemas.Message:
     with get_session() as session:
         chat = database.get_chat(session, user=user, id=id)
         chat.messages.append(
             schemas.Message(content=prompt, role=ragna.core.MessageRole.USER)
         )
-
         core_chat = schema_to_core_chat(session, user=user, chat=chat)
 
-        answer = schemas.Message.from_core(await core_chat.answer(prompt))
+    core_answer = await core_chat.answer(prompt, stream=stream)
+
+    if stream:
+        message_chunk = schemas.Message(
+            content="",
+            role=core_answer.role,
+            sources=[
+                schemas.Source.from_core(source) for source in core_answer.sources
+            ],
+        )
+
+        async def message_chunks() -> AsyncIterator[sse_starlette.ServerSentEvent]:
+            chunks = []
+            async for chunk in core_answer:
+                chunks.append(chunk)
+                message_chunk.content = chunk
+                yield sse_starlette.ServerSentEvent(message_chunk.model_dump_json())
+
+            with get_session() as session:
+                answer = message_chunk
+                answer.content = "".join(chunks)
+                chat.messages.append(answer)
+                database.update_chat(session, user=user, chat=chat)
+
+        return sse_starlette.EventSourceResponse(  # type: ignore[return-value]
+            message_chunks()
+        )
+    else:
+        answer = schemas.Message.from_core(core_answer)
 
-        chat.messages.append(answer)
-        database.update_chat(session, user=user, chat=chat)
+        with get_session() as session:
+            chat.messages.append(answer)
+            database.update_chat(session, user=user, chat=chat)
 
     return answer
diff --git a/ragna/deploy/_ui/api_wrapper.py b/ragna/deploy/_ui/api_wrapper.py
index fc5aaa8e..4b5c4386 100644
--- a/ragna/deploy/_ui/api_wrapper.py
+++ b/ragna/deploy/_ui/api_wrapper.py
@@ -1,7 +1,9 @@
+import json
 from datetime import datetime
 
 import emoji
 import httpx
+import httpx_sse
 import param
 
 
@@ -62,17 +64,14 @@ async def get_chats(self):
         return json_data
 
     async def answer(self, chat_id, prompt):
-        return self.improve_message(
-            (
-                await self.client.post(
-                    f"/chats/{chat_id}/answer",
-                    params={"prompt": prompt},
-                    timeout=None,
-                )
-            )
-            .raise_for_status()
-            .json()
-        )
+        async with httpx_sse.aconnect_sse(
+            self.client,
+            "POST",
+            f"/chats/{chat_id}/answer",
+            json={"prompt": prompt, "stream": True},
+        ) as event_source:
+            async for sse in event_source.aiter_sse():
+                yield self.improve_message(json.loads(sse.data))
 
     async def get_components(self):
         return (await self.client.get("/components")).raise_for_status().json()
diff --git a/ragna/deploy/_ui/central_view.py b/ragna/deploy/_ui/central_view.py
index 8614b8f1..8fe5a92e 100644
--- a/ragna/deploy/_ui/central_view.py
+++ b/ragna/deploy/_ui/central_view.py
@@ -6,6 +6,8 @@
 import param
 from panel.reactive import ReactiveHTML
 
+from ragna._compat import anext
+
 from . import styles as ui
 
 # TODO : move all the CSS rules in a dedicated file
@@ -91,7 +93,6 @@ class RagnaChatMessage(pn.chat.ChatMessage):
     role: str = param.Selector(objects=["system", "user", "assistant"])
     sources = param.List(allow_None=True)
     on_click_source_info_callback = param.Callable(allow_None=True)
-    _content_style_declarations = param.Dict(constant=True)
 
     def __init__(
         self,
@@ -104,8 +105,48 @@ def __init__(
         timestamp=None,
         show_timestamp=True,
     ):
+        css_class = f"message-content-{role}"
+        self.content_pane = pn.pane.Markdown(
+            content,
+            css_classes=["message-content", css_class],
+            stylesheets=ui.stylesheets(
+                (
+                    "table",
+                    {"margin-top": "10px", "margin-bottom": "10px"},
+                )
+            ),
+        )
+
+        if role == "assistant":
+            assert sources is not None
+            css_class = "message-content-assistant-with-buttons"
+            object = pn.Column(
+                self.content_pane,
+                self._copy_and_source_view_buttons(),
+                css_classes=[css_class],
+            )
+        else:
+            object = self.content_pane
+
+        object.stylesheets.extend(
+            ui.stylesheets(
+                (
+                    f":host(.{css_class})",
+                    {"background-color": "rgb(243, 243, 243) !important"}
+                    if role == "user"
+                    else {
+                        "background-color": "none",
+                        "border": "rgb(234, 234, 234)",
+                        "border-style": "solid",
+                        "border-width": "1.2px",
+                        "border-radius": "5px",
+                    },
+                )
+            ),
+        )
+
         super().__init__(
-            object=content,
+            object=object,
             role=role,
             user=user,
             sources=sources,
@@ -116,42 +157,13 @@ def __init__(
             show_user=False,
             show_copy_icon=False,
             css_classes=[f"message-{role}"],
-            renderers=[self._render],
-            _content_style_declarations={
-                "background-color": "rgb(243, 243, 243) !important"
-            }
-            if role == "user"
-            else {
-                "background-color": "none",
-                "border": "rgb(234, 234, 234)",
-                "border-style": "solid",
-                "border-width": "1.2px",
-                "border-radius": "5px",
-            },
         )
         self._stylesheets.extend(message_stylesheets)
 
-        if self.sources:
-            self._update_object_pane()
-
-    def _update_object_pane(self, event=None):
-        super()._update_object_pane(event)
-        if self.sources:
-            assert self.role == "assistant"
-            css_class = "message-content-assistant-with-buttons"
-            self._object_panel = self._center_row[0] = pn.Column(
-                self._object_panel,
-                self._copy_and_source_view_buttons(),
-                css_classes=["message", css_class],
-                stylesheets=ui.stylesheets(
-                    (f":host(.{css_class})", self._content_style_declarations)
-                ),
-            )
-
     def _copy_and_source_view_buttons(self) -> pn.Row:
         return pn.Row(
             CopyToClipboardButton(
-                value=self.object,
+                value=self.content_pane.object,
                 title="Copy",
                 stylesheets=[
                     ui.CHAT_INTERFACE_CUSTOM_BUTTON,
@@ -193,25 +205,6 @@ def avatar_lookup(self, user: str) -> str:
 
         return model[0].upper()
 
-    def _render(self, content: str) -> pn.pane.Markdown:
-        class_selectors = [
-            (
-                "table",
-                {"margin-top": "10px", "margin-bottom": "10px"},
-            )
-        ]
-        if self.role != "assistant":
-            # The styling for the assistant messages is applied self._update_object_pane
-            # since it needs to apply to the content as well as the buttons.
-            class_selectors.append(
-                (":host(.message-content)", self._content_style_declarations)
-            )
-        return pn.pane.Markdown(
-            content,
-            css_classes=["message-content", f"message-content-{self.role}"],
-            stylesheets=ui.stylesheets(*class_selectors),
-        )
-
 
 class RagnaChatInterface(pn.chat.ChatInterface):
     get_user_from_role = param.Callable(allow_None=True)
@@ -370,15 +363,21 @@ async def chat_callback(
         self, content: str, user: str, instance: pn.chat.ChatInterface
     ):
         try:
-            answer = await self.api_wrapper.answer(self.current_chat["id"], content)
+            answer_stream = self.api_wrapper.answer(self.current_chat["id"], content)
+            answer = await anext(answer_stream)
 
-            yield RagnaChatMessage(
+            message = RagnaChatMessage(
                 answer["content"],
                 role="assistant",
                 user=self.get_user_from_role("assistant"),
                 sources=answer["sources"],
                 on_click_source_info_callback=self.on_click_source_info_wrapper,
             )
+            yield message
+
+            async for chunk in answer_stream:
+                message.content_pane.object += chunk["content"]
+
         except Exception:
             yield RagnaChatMessage(
                 (
diff --git a/tests/core/test_components.py b/tests/core/test_components.py
new file mode 100644
index 00000000..64fc398b
--- /dev/null
+++ b/tests/core/test_components.py
@@ -0,0 +1,83 @@
+import asyncio
+import functools
+
+import pytest
+
+from ragna.core import Message
+
+
+def sync(async_test_fn):
+    @functools.wraps(async_test_fn)
+    def wrapper(*args, **kwargs):
+        return asyncio.run(async_test_fn(*args, **kwargs))
+
+    return wrapper
+
+
+class TestMessage:
+    def test_fixed_content(self):
+        content = "content"
+        message = Message(content)
+
+        assert message.content == content
+        assert str(message) == content
+
+    @sync
+    async def test_fixed_content_read(self):
+        content = "content"
+        message = Message(content)
+
+        assert (await message.read()) == content
+
+    @sync
+    async def test_fixed_content_iter(self):
+        content = "content"
+        message = Message(content)
+
+        chunks = []
+        async for chunk in message:
+            chunks.append(chunk)
+        assert chunks == [content]
+
+    def make_content_stream(self, *chunks):
+        async def content_stream():
+            for chunk in chunks:
+                yield chunk
+
+        return content_stream()
+
+    @pytest.mark.parametrize(
+        "content_access",
+        [
+            pytest.param(lambda message: message.content, id="property"),
+            str,
+            repr,
+        ],
+    )
+    def test_stream_content_access_error(self, content_access):
+        content = "content"
+        message = Message(self.make_content_stream(*content))
+
+        with pytest.raises(RuntimeError):
+            content_access(message)
+
+    @sync
+    async def test_stream_content_iter(self):
+        content = "content"
+        message = Message(self.make_content_stream(*content))
+
+        chunks = []
+        async for chunk in message:
+            chunks.append(chunk)
+        assert chunks == list(content)
+
+        assert message.content == content
+
+    @sync
+    async def test_stream_content_read(self):
+        content = "content"
+        message = Message(self.make_content_stream(*content))
+
+        assert (await message.read()) == content
+
+        assert message.content == content
diff --git a/tests/deploy/api/test_e2e.py b/tests/deploy/api/test_e2e.py
index a1c69da9..04994c8e 100644
--- a/tests/deploy/api/test_e2e.py
+++ b/tests/deploy/api/test_e2e.py
@@ -1,6 +1,8 @@
+import json
 import os
 
 import httpx
+import httpx_sse
 import pytest
 
 from ragna._utils import timeout_after
@@ -10,7 +12,8 @@
 
 
 @pytest.mark.parametrize("database", ["memory", "sqlite"])
-def test_e2e(tmp_local_root, database):
+@pytest.mark.parametrize("stream_answer", [True, False])
+def test_e2e(tmp_local_root, database, stream_answer):
     if database == "memory":
         database_url = "memory"
     elif database == "sqlite":
@@ -19,11 +22,11 @@ def test_e2e(tmp_local_root, database):
     config = Config(
         local_cache_root=tmp_local_root, api=dict(database_url=database_url)
     )
-    check_api(config)
+    check_api(config, stream_answer=stream_answer)
 
 
 @timeout_after()
-def check_api(config):
+def check_api(config, *, stream_answer):
     document_root = config.local_cache_root / "documents"
     document_root.mkdir()
     document_path = document_root / "test.txt"
@@ -109,11 +112,26 @@ def check_api(config):
     assert chat["messages"][-1] == message
 
     prompt = "?"
-    message = (
-        client.post(f"/chats/{chat['id']}/answer", params={"prompt": prompt})
-        .raise_for_status()
-        .json()
-    )
+    if stream_answer:
+        chunks = []
+        with httpx_sse.connect_sse(
+            client,
+            "POST",
+            f"/chats/{chat['id']}/answer",
+            json={"prompt": prompt, "stream": True},
+        ) as event_source:
+            for sse in event_source.iter_sse():
+                chunk = json.loads(sse.data)
+                chunks.append(chunk["content"])
+        message = chunk
+        message["content"] = "".join(chunks)
+    else:
+        message = (
+            client.post(f"/chats/{chat['id']}/answer", json={"prompt": prompt})
+            .raise_for_status()
+            .json()
+        )
+
    assert message["role"] == "assistant"
    assert {source["document"]["name"] for source in message["sources"]} == {
        document_path.name
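
Reviewer note, not part of the patch: each server-sent event emitted by POST /chats/{id}/answer carries a full schemas.Message JSON whose "content" field holds only the latest chunk, not the accumulated text, so clients must join the chunks themselves, exactly as api_wrapper.answer() and the e2e test above do. A minimal synchronous consumer sketch under that assumption (client and chat_id are illustrative placeholders, not part of this diff):

    import json

    import httpx
    import httpx_sse

    def stream_answer(client: httpx.Client, chat_id: str, prompt: str) -> str:
        # Request a streamed answer and read the response as Server-Sent Events.
        chunks = []
        with httpx_sse.connect_sse(
            client,
            "POST",
            f"/chats/{chat_id}/answer",
            json={"prompt": prompt, "stream": True},
        ) as event_source:
            for sse in event_source.iter_sse():
                message = json.loads(sse.data)
                chunks.append(message["content"])  # one chunk per event
        return "".join(chunks)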
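
On the Python-API side, Chat.answer(..., stream=True) now returns a Message whose content arrives lazily. A short sketch of the intended consumption, assuming chat is an already prepared ragna.core.Chat (the prompt string is illustrative):

    message = await chat.answer("What does the new Message class do?", stream=True)

    # Stream the chunks as the assistant produces them ...
    async for chunk in message:
        print(chunk, end="", flush=True)

    # ... after which the joined content is cached on the message, so both
    # the content property and read() return the full text.
    assert message.content == await message.read()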