use JSONL streaming over SSE #357

Merged
merged 3 commits on Mar 15, 2024
23 changes: 10 additions & 13 deletions docs/examples/gallery_streaming.py
@@ -141,41 +141,38 @@ def answer(self, prompt, sources):
client.post(f"/chats/{chat['id']}/prepare").raise_for_status()

# %%
# Streaming the response is performed with
# [server-sent events (SSE)](https://en.wikipedia.org/wiki/Server-sent_events).
# Streaming the response is performed with [JSONL](https://jsonlines.org/). Each line
# in the response is valid JSON and corresponds to one chunk.

import httpx_sse
import json

chunks = []
with httpx_sse.connect_sse(
client,

with client.stream(
"POST",
f"/chats/{chat['id']}/answer",
json={"prompt": "What is Ragna?", "stream": True},
) as event_source:
for sse in event_source.iter_sse():
chunks.append(json.loads(sse.data))
) as response:
chunks = [json.loads(data) for data in response.iter_lines()]

# %%
# The first event contains the full message object including the sources along the first
# The first chunk contains the full message object including the sources along the first
# chunk of the content.

print(len(chunks))
print(json.dumps(chunks[0], indent=2))

# %%
# Subsequent events no longer contain the sources.
# Subsequent chunks no longer contain the sources.

print(json.dumps(chunks[1], indent=2))

# %%
# Joining the chunks together results in the full message.
# Joining the content of the chunks together results in the full message.

print("".join(chunk["content"] for chunk in chunks))

# %%
# Before we close the tutorial, let's stop the REST API and have a look at what would
# Before we close the example, let's stop the REST API and have a look at what would
# have printed in the terminal if we had started it the regular way.

rest_api.stop()
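For readers skimming the diff, here is a minimal standalone sketch of the JSONL consumption pattern the example now uses. The base URL and chat id are hypothetical placeholders; in the gallery script the `client` is created and authenticated earlier.

```python
import json

import httpx

# Hypothetical base URL; the gallery script builds an authenticated client earlier.
client = httpx.Client(base_url="http://127.0.0.1:31476")

with client.stream(
    "POST",
    "/chats/some-chat-id/answer",  # placeholder chat id
    json={"prompt": "What is Ragna?", "stream": True},
) as response:
    # Every line of the body is one self-contained JSON document (one chunk).
    chunks = [json.loads(line) for line in response.iter_lines()]

# The first chunk carries the sources; later chunks only carry content.
print("".join(chunk["content"] for chunk in chunks))
```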
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -25,7 +25,6 @@ dependencies = [
"emoji",
"fastapi",
"httpx",
"httpx-sse",
"importlib_metadata>=4.6; python_version<'3.10'",
"packaging",
"panel==1.3.8",
@@ -38,7 +37,6 @@ dependencies = [
"questionary",
"rich",
"sqlalchemy>=2",
"sse-starlette",
"starlette",
"tomlkit",
"typer",
@@ -56,6 +54,7 @@ Repository = "https://github.com/Quansight/ragna"
# to update the array below, run scripts/update_optional_dependencies.py
all = [
"chromadb>=0.4.13",
"httpx_sse",
"ijson",
"lancedb>=0.2",
"pyarrow",
18 changes: 15 additions & 3 deletions ragna/assistants/_anthropic.py
@@ -1,9 +1,7 @@
import json
from typing import AsyncIterator, cast

import httpx_sse

from ragna.core import RagnaException, Source
from ragna.core import PackageRequirement, RagnaException, Requirement, Source

from ._api import ApiAssistant

@@ -12,6 +10,10 @@ class AnthropicApiAssistant(ApiAssistant):
_API_KEY_ENV_VAR = "ANTHROPIC_API_KEY"
_MODEL: str

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("httpx_sse")]

@classmethod
def display_name(cls) -> str:
return f"Anthropic/{cls._MODEL}"
@@ -29,6 +31,8 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[str]:
import httpx_sse

# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
@@ -68,6 +72,10 @@ class ClaudeInstant(AnthropicApiAssistant):
!!! info "Required environment variables"
- `ANTHROPIC_API_KEY`
!!! info "Required packages"
- `httpx_sse`
"""

_MODEL = "claude-instant-1"
@@ -79,6 +87,10 @@ class Claude(AnthropicApiAssistant):
!!! info "Required environment variables"
- `ANTHROPIC_API_KEY`
!!! info "Required packages"
- `httpx_sse`
"""

_MODEL = "claude-2"
6 changes: 5 additions & 1 deletion ragna/assistants/_api.py
@@ -16,7 +16,11 @@ class ApiAssistant(Assistant):

@classmethod
def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]
return [EnvVarRequirement(cls._API_KEY_ENV_VAR), *cls._extra_requirements()]

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return []

def __init__(self) -> None:
self._client = httpx.AsyncClient(
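A sketch of how an assistant can now declare extra package requirements through the `_extra_requirements` hook added above. The `CustomAssistant` class, its environment variable, and its model name are made up for illustration; only the `ragna.core` names mirror the diff.

```python
from typing import AsyncIterator

from ragna.core import PackageRequirement, Requirement, Source
from ragna.assistants._api import ApiAssistant  # private module, as in the diff


class CustomAssistant(ApiAssistant):
    _API_KEY_ENV_VAR = "CUSTOM_API_KEY"  # hypothetical
    _MODEL = "custom-model"  # hypothetical

    @classmethod
    def _extra_requirements(cls) -> list[Requirement]:
        # Declared here instead of overriding requirements(); the base class
        # combines this with the EnvVarRequirement automatically.
        return [PackageRequirement("httpx_sse")]

    async def _call_api(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int
    ) -> AsyncIterator[str]:
        import httpx_sse  # lazy import, so the package is only needed at call time

        yield "..."  # a real assistant would stream tokens from its API here
```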
7 changes: 2 additions & 5 deletions ragna/assistants/_google.py
@@ -30,11 +30,8 @@ class GoogleApiAssistant(ApiAssistant):
_MODEL: str

@classmethod
def requirements(cls) -> list[Requirement]:
return [
*super().requirements(),
PackageRequirement("ijson"),
]
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("ijson")]

@classmethod
def display_name(cls) -> str:
18 changes: 15 additions & 3 deletions ragna/assistants/_openai.py
@@ -1,9 +1,7 @@
import json
from typing import AsyncIterator, cast

import httpx_sse

from ragna.core import Source
from ragna.core import PackageRequirement, Requirement, Source

from ._api import ApiAssistant

@@ -12,6 +10,10 @@ class OpenaiApiAssistant(ApiAssistant):
_API_KEY_ENV_VAR = "OPENAI_API_KEY"
_MODEL: str

@classmethod
def _extra_requirements(cls) -> list[Requirement]:
return [PackageRequirement("httpx_sse")]

@classmethod
def display_name(cls) -> str:
return f"OpenAI/{cls._MODEL}"
@@ -28,6 +30,8 @@ def _make_system_content(self, sources: list[Source]) -> str:
async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> AsyncIterator[str]:
import httpx_sse

# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/streaming
async with httpx_sse.aconnect_sse(
@@ -72,6 +76,10 @@ class Gpt35Turbo16k(OpenaiApiAssistant):
!!! info "Required environment variables"
- `OPENAI_API_KEY`
!!! info "Required packages"
- `httpx_sse`
"""

_MODEL = "gpt-3.5-turbo-16k"
@@ -83,6 +91,10 @@ class Gpt4(OpenaiApiAssistant):
!!! info "Required environment variables"
- `OPENAI_API_KEY`
!!! info "Required packages"
- `httpx_sse`
"""

_MODEL = "gpt-4"
18 changes: 11 additions & 7 deletions ragna/deploy/_api/core.py
@@ -3,7 +3,6 @@
from typing import Annotated, Any, AsyncIterator, Iterator, Type, cast

import aiofiles
import sse_starlette
from fastapi import (
Body,
Depends,
@@ -15,7 +14,8 @@
status,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

import ragna
import ragna.core
@@ -273,7 +273,7 @@ async def answer(

if stream:

async def message_chunks() -> AsyncIterator[sse_starlette.ServerSentEvent]:
async def message_chunks() -> AsyncIterator[BaseModel]:
core_answer_stream = aiter(core_answer)
content_chunk = await anext(core_answer_stream)

@@ -285,23 +285,27 @@ async def message_chunks() -> AsyncIterator[sse_starlette.ServerSentEvent]:
for source in core_answer.sources
],
)
yield sse_starlette.ServerSentEvent(answer.model_dump_json())
yield answer

# Avoid sending the sources multiple times
answer_chunk = answer.model_copy(update=dict(sources=None))
content_chunks = [answer_chunk.content]
async for content_chunk in core_answer_stream:
content_chunks.append(content_chunk)
answer_chunk.content = content_chunk
yield sse_starlette.ServerSentEvent(answer_chunk.model_dump_json())
yield answer_chunk

with get_session() as session:
answer.content = "".join(content_chunks)
chat.messages.append(answer)
database.update_chat(session, user=user, chat=chat)

return sse_starlette.EventSourceResponse( # type: ignore[return-value]
message_chunks()
async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]:
async for model in models:
yield f"{model.model_dump_json()}\n"

return StreamingResponse( # type: ignore[return-value]
to_jsonl(message_chunks())
)
else:
answer = schemas.Message.from_core(core_answer)
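Outside the Ragna codebase, the server-side technique adopted here boils down to streaming newline-delimited JSON from an async generator. Below is a self-contained FastAPI sketch with an illustrative `Chunk` model and endpoint path rather than Ragna's actual schema.

```python
from typing import Any, AsyncIterator, Optional

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class Chunk(BaseModel):
    content: str
    sources: Optional[list[str]] = None


@app.post("/demo/answer")
async def answer() -> StreamingResponse:
    async def chunks() -> AsyncIterator[BaseModel]:
        yield Chunk(content="Hello, ", sources=["doc.txt"])  # first chunk carries sources
        yield Chunk(content="world!")  # later chunks omit them

    async def to_jsonl(models: AsyncIterator[Any]) -> AsyncIterator[str]:
        # One JSON document per line, so clients can decode chunks as they arrive.
        async for model in models:
            yield f"{model.model_dump_json()}\n"

    return StreamingResponse(to_jsonl(chunks()), media_type="application/x-ndjson")
```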
10 changes: 4 additions & 6 deletions ragna/deploy/_ui/api_wrapper.py
@@ -3,7 +3,6 @@

import emoji
import httpx
import httpx_sse
import param


@@ -64,14 +63,13 @@ async def get_chats(self):
return json_data

async def answer(self, chat_id, prompt):
async with httpx_sse.aconnect_sse(
self.client,
async with self.client.stream(
"POST",
f"/chats/{chat_id}/answer",
json={"prompt": prompt, "stream": True},
) as event_source:
async for sse in event_source.aiter_sse():
yield self.improve_message(json.loads(sse.data))
) as response:
async for data in response.aiter_lines():
yield self.improve_message(json.loads(data))

async def get_components(self):
return (await self.client.get("/components")).raise_for_status().json()
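An async counterpart of the wrapper change above, written as a standalone helper; it assumes an `httpx.AsyncClient` that is already configured with the API's base URL and auth headers.

```python
import json
from typing import Any, AsyncIterator

import httpx


async def stream_answer(
    client: httpx.AsyncClient, chat_id: str, prompt: str
) -> AsyncIterator[Any]:
    # Stream the response body and decode one JSON chunk per line as it arrives.
    async with client.stream(
        "POST",
        f"/chats/{chat_id}/answer",
        json={"prompt": prompt, "stream": True},
    ) as response:
        async for line in response.aiter_lines():
            yield json.loads(line)
```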
6 changes: 0 additions & 6 deletions requirements-docker.lock
@@ -11,7 +11,6 @@ annotated-types==0.6.0
anyio==4.2.0
# via
# httpx
# sse-starlette
# starlette
# watchfiles
asgiref==3.7.2
@@ -71,7 +70,6 @@ fastapi==0.109.0
# via
# Ragna (pyproject.toml)
# chromadb
# sse-starlette
filelock==3.13.1
# via huggingface-hub
flatbuffers==23.5.26
@@ -342,13 +340,10 @@ sniffio==1.3.0
# httpx
sqlalchemy==2.0.25
# via Ragna (pyproject.toml)
sse-starlette==1.8.2
# via Ragna (pyproject.toml)
starlette==0.35.1
# via
# Ragna (pyproject.toml)
# fastapi
# sse-starlette
sympy==1.12
# via onnxruntime
tenacity==8.2.3
@@ -395,7 +390,6 @@ uvicorn==0.26.0
# via
# Ragna (pyproject.toml)
# chromadb
# sse-starlette
uvloop==0.19.0
# via uvicorn
watchfiles==0.21.0
16 changes: 3 additions & 13 deletions tests/deploy/api/test_e2e.py
@@ -1,6 +1,5 @@
import json

import httpx_sse
import pytest
from fastapi.testclient import TestClient

@@ -33,12 +32,6 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):
with open(document_path, "w") as file:
file.write("!\n")

# Reset starlette_sse AppStatus for each run
# See https://github.com/sysid/sse-starlette/issues/59
from sse_starlette.sse import AppStatus

AppStatus.should_exit_event = None

with TestClient(app(config=config, ignore_unavailable_components=False)) as client:
authenticate(client)

@@ -104,15 +97,12 @@ def test_e2e(tmp_local_root, multiple_answer_chunks, stream_answer):

prompt = "?"
if stream_answer:
chunks = []
with httpx_sse.connect_sse(
client,
with client.stream(
"POST",
f"/chats/{chat['id']}/answer",
json={"prompt": prompt, "stream": True},
) as event_source:
for sse in event_source.iter_sse():
chunks.append(json.loads(sse.data))
) as response:
chunks = [json.loads(chunk) for chunk in response.iter_lines()]
message = chunks[0]
assert all(chunk["sources"] is None for chunk in chunks[1:])
message["content"] = "".join(chunk["content"] for chunk in chunks)
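A hedged sketch of the reassembly contract the updated test exercises: only the first JSONL chunk carries sources, and the concatenated contents form the full message. The chunk payloads below are fabricated for illustration.

```python
import json

# Two fabricated JSONL lines, mimicking the shape of the streamed answer.
lines = [
    '{"content": "Hello, ", "sources": [{"id": "s1"}]}',
    '{"content": "world!", "sources": null}',
]
chunks = [json.loads(line) for line in lines]

assert chunks[0]["sources"] is not None
assert all(chunk["sources"] is None for chunk in chunks[1:])
assert "".join(chunk["content"] for chunk in chunks) == "Hello, world!"
```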