Skip to content

Commit

Permalink
implement streaming for assistants (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier authored Jan 17, 2024
1 parent ea9ffeb commit c08a223
Show file tree
Hide file tree
Showing 14 changed files with 381 additions and 144 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ classifiers = [
requires-python = ">=3.9"
dependencies = [
"aiofiles",
"anyio",
"emoji",
"fastapi",
"httpx",
"httpx_sse",
"importlib_metadata>=4.6; python_version<'3.10'",
"packaging",
"panel>=1.3.6,<1.4",
Expand All @@ -38,6 +38,8 @@ dependencies = [
"questionary",
"rich",
"sqlalchemy>=2",
"sse-starlette",
"starlette",
"tomlkit",
"typer",
"uvicorn",
Expand Down
29 changes: 26 additions & 3 deletions ragna/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import builtins
import sys
from typing import Callable, Iterable, Iterator, Mapping, TypeVar

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"]
from typing import (
AsyncIterator,
Awaitable,
Callable,
Iterable,
Iterator,
Mapping,
TypeVar,
)

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions", "anext"]

T = TypeVar("T")

Expand Down Expand Up @@ -38,3 +47,17 @@ def _importlib_metadata_package_distributions() -> (


importlib_metadata_package_distributions = _importlib_metadata_package_distributions()


def _anext() -> Callable[[AsyncIterator[T]], Awaitable[T]]:
if sys.version_info[:2] >= (3, 10):
anext = builtins.anext
else:

async def anext(ait: AsyncIterator[T]) -> T:
return await ait.__anext__()

return anext


anext = _anext()
31 changes: 21 additions & 10 deletions ragna/assistants/_anthropic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import cast
import json
from typing import AsyncIterator, cast

import httpx_sse

from ragna.core import RagnaException, Source

Expand Down Expand Up @@ -30,9 +33,11 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
# # See https://docs.anthropic.com/claude/reference/complete_post
response = await self._client.post(
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
"POST",
"https://api.anthropic.com/v1/complete",
headers={
"accept": "application/json",
Expand All @@ -45,13 +50,19 @@ async def _call_api(
"prompt": self._instructize_prompt(prompt, sources),
"max_tokens_to_sample": max_new_tokens,
"temperature": 0.0,
"stream": True,
},
)
if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["completion"])
) as event_source:
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
if data["type"] != "completion":
continue
elif "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["stop_reason"] is not None:
break

yield cast(str, data["completion"])


class ClaudeInstant(AnthropicApiAssistant):
Expand Down
14 changes: 9 additions & 5 deletions ragna/assistants/_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import abc
import os
from typing import AsyncIterator

import httpx

import ragna
from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
Expand All @@ -13,8 +16,6 @@ def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]

def __init__(self) -> None:
import httpx

self._client = httpx.AsyncClient(
headers={"User-Agent": f"{ragna.__version__}/{self}"},
timeout=60,
Expand All @@ -23,11 +24,14 @@ def __init__(self) -> None:

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> str:
return await self._call_api(prompt, sources, max_new_tokens=max_new_tokens)
) -> AsyncIterator[str]:
async for chunk in self._call_api( # type: ignore[attr-defined, misc]
prompt, sources, max_new_tokens=max_new_tokens
):
yield chunk

@abc.abstractmethod
async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
...
7 changes: 4 additions & 3 deletions ragna/assistants/_demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import sys
import textwrap
from typing import Iterator

from ragna.core import Assistant, Source

Expand All @@ -26,11 +27,11 @@ def display_name(cls) -> str:
def max_input_size(self) -> int:
return sys.maxsize

def answer(self, prompt: str, sources: list[Source]) -> str:
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
if re.search("markdown", prompt, re.IGNORECASE):
return self._markdown_answer()
yield self._markdown_answer()
else:
return self._default_answer(prompt, sources)
yield self._default_answer(prompt, sources)

def _markdown_answer(self) -> str:
return textwrap.dedent(
Expand Down
6 changes: 3 additions & 3 deletions ragna/assistants/_mosaicml.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import cast
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source

Expand Down Expand Up @@ -29,7 +29,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
instruction = self._instructize_prompt(prompt, sources)
# https://docs.mosaicml.com/en/latest/inference.html#text-completion-requests
response = await self._client.post(
Expand All @@ -47,7 +47,7 @@ async def _call_api(
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["outputs"][0]).replace(instruction, "").strip()
yield cast(str, response.json()["outputs"][0]).replace(instruction, "").strip()


class Mpt7bInstruct(MosaicmlApiAssistant):
Expand Down
33 changes: 19 additions & 14 deletions ragna/assistants/_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import cast
import json
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
import httpx_sse

from ragna.core import Source

from ._api import ApiAssistant

Expand Down Expand Up @@ -29,10 +32,12 @@ def _make_system_content(self, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/object
response = await self._client.post(
# and https://platform.openai.com/docs/api-reference/chat/streaming
async with httpx_sse.aconnect_sse(
self._client,
"POST",
"https://api.openai.com/v1/chat/completions",
headers={
"Content-Type": "application/json",
Expand All @@ -52,13 +57,16 @@ async def _call_api(
"model": self._MODEL,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"stream": True,
},
)
if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["choices"][0]["message"]["content"])
) as event_source:
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])


class Gpt35Turbo16k(OpenaiApiAssistant):
Expand All @@ -73,9 +81,6 @@ class Gpt35Turbo16k(OpenaiApiAssistant):
_CONTEXT_SIZE = 16_384


# (fix) Removed stray line `Gpt35Turbo16k.__doc__ = "OOPS"`: it overwrote the
# class docstring defined above, destroying the user-facing documentation that
# tooling (help(), docs generators) reads from `__doc__`.


class Gpt4(OpenaiApiAssistant):
"""[OpenAI GPT-4](https://platform.openai.com/docs/models/gpt-4)
Expand Down
61 changes: 54 additions & 7 deletions ragna/core/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import enum
import functools
import inspect
from typing import Type
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
import pydantic.utils
Expand Down Expand Up @@ -138,11 +138,10 @@ class MessageRole(enum.Enum):
ASSISTANT = "assistant"


class Message(pydantic.BaseModel):
class Message:
"""Data class for messages.
Attributes:
content: The content of the message.
role: The message producer.
sources: The sources used to produce the message.
Expand All @@ -152,13 +151,61 @@ class Message(pydantic.BaseModel):
- [ragna.core.Chat.answer][]
"""

content: str
role: MessageRole
sources: list[Source] = pydantic.Field(default_factory=list)
def __init__(
    self,
    content: Union[str, AsyncIterable[str]],
    *,
    role: MessageRole = MessageRole.SYSTEM,
    sources: Optional[list[Source]] = None,
) -> None:
    """Initialize the message with eager or streamed content.

    Args:
        content: Either the full message text, or an async iterable that
            yields the text in chunks (e.g. a streaming assistant response).
        role: The message producer.
        sources: The sources used to produce the message.
    """
    if isinstance(content, str):
        # Eager content: immediately available via the `content` property.
        self._content: str = content
    else:
        # Lazy content: only `_content_stream` exists until the stream is
        # consumed via `__aiter__` or `read`; `_content` is set afterwards.
        self._content_stream: AsyncIterable[str] = content

    self.role = role
    self.sources = sources or []

async def __aiter__(self) -> AsyncIterator[str]:
    """Yield the content in chunks, caching the full text as a side effect."""
    if hasattr(self, "_content"):
        # Content is already materialized: emit it as a single chunk.
        yield self._content
        return

    chunks = []
    async for chunk in self._content_stream:
        chunks.append(chunk)
        yield chunk

    # Cache the joined chunks so `content` / `read` work after iteration.
    self._content = "".join(chunks)

async def read(self) -> str:
    """Consume the content stream (if not yet consumed) and return the full content."""
    if not hasattr(self, "_content"):
        # Since self.__aiter__ is already setting the self._content attribute, we
        # only need to exhaust the content stream here.
        async for _ in self:
            pass
    return self._content

@property
def content(self) -> str:
    """The full message content.

    Raises:
        RuntimeError: If the content stream has not been consumed yet
            (via iteration or `read`), since the text is not available.
    """
    if not hasattr(self, "_content"):
        raise RuntimeError(
            "Message content cannot be accessed without having iterated over it, "
            "e.g. `async for chunk in message`, or reading the content, e.g. "
            "`await message.read()`, first."
        )
    return self._content

def __str__(self) -> str:
    # Delegates to the `content` property; like it, this raises RuntimeError
    # when the content stream has not been consumed yet.
    return self.content

def __repr__(self) -> str:
    """Debug representation showing content, role, and sources.

    (fix) The previous implementation read `self.content`, which raises
    RuntimeError while the content stream is unconsumed — so `repr()` could
    raise inside loggers/debuggers. Fall back to a placeholder instead.
    """
    content = getattr(self, "_content", "<content stream not consumed>")
    return (
        f"{type(self).__name__}("
        f"content={content}, role={self.role}, sources={self.sources}"
        f")"
    )


class Assistant(Component, abc.ABC):
"""Abstract base class for assistants used in [ragna.core.Chat][]"""
Expand All @@ -171,7 +218,7 @@ def max_input_size(self) -> int:
...

@abc.abstractmethod
def answer(self, prompt: str, sources: list[Source]) -> str:
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
"""Answer a prompt given some sources.
Args:
Expand Down
Loading

0 comments on commit c08a223

Please sign in to comment.