Commit

Merge pull request #190 from alan-turing-institute/streaming-response

Streaming response

rchan26 authored Jun 12, 2024
2 parents 2899713 + 4a44dab commit bb0a9db
Showing 12 changed files with 361 additions and 89 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
@@ -56,6 +56,7 @@ typer = {extras = ["all"], version = "^0.12.3"}
langchain-community = "^0.2.4"
tiktoken = "^0.7.0"
llama-index-embeddings-huggingface = "^0.2.1"
rich = "^13.7.1"


[tool.poetry.group.dev.dependencies]
@@ -105,10 +106,13 @@ build-backend = "poetry.core.masonry.api"
minversion = "6.0"
testpaths = [
"tests",
"reginald",
]
addopts = """
--cov=estios
--cov=reginald
--cov-report=term:skip-covered
--cov-append
--pdbcls=IPython.terminal.debugger:TerminalPdb
--doctest-modules
"""
doctest_optionflags = ["NORMALIZE_WHITESPACE", "ELLIPSIS",]
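
The new addopts point coverage at the reginald package and, via --doctest-modules, collect docstring examples from the source tree as tests. A minimal sketch of a docstring this configuration would execute (the greet function is hypothetical, not part of this diff); the ELLIPSIS flag lets "..." elide the tail of the expected output, and NORMALIZE_WHITESPACE relaxes spacing differences:

```python
# Hypothetical module: under --doctest-modules, pytest runs this docstring.
def greet(user_id: str) -> str:
    """Return Reginald's greeting for a Slack user.

    >>> greet("U123")
    'Hello <@U123>...'
    """
    # With ELLIPSIS enabled, the trailing "..." in the expected output
    # matches the rest of the returned string.
    return f"Hello <@{user_id}>! How are you?"
```
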
25 changes: 22 additions & 3 deletions reginald/cli.py
@@ -25,6 +25,7 @@
"device": "Device to use (ignored if not using llama-index).",
"api_url": "API URL for the Reginald app.",
"emoji": "Emoji to use for the bot.",
"streaming": "Whether to use streaming for the chat interaction.",
}

cli = typer.Typer()
@@ -102,6 +103,11 @@ def run_all(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run all the components of the Reginald slack bot.
Establishes the connection to the Slack API, sets up the bot,
and creates a Reginald model to query from.
"""
set_up_logging_config(level=20)
main(
cli="run_all",
@@ -135,7 +141,7 @@ def bot(
] = EMOJI_DEFAULT,
) -> None:
"""
Main function to run the Slack bot which sets up the bot
Run the Slack bot which sets up the bot
(which uses an API for responding to messages) and
then establishes a WebSocket connection to the
Socket Mode servers and listens for events.
@@ -213,8 +219,8 @@ def app(
] = DEFAULT_ARGS["device"],
) -> None:
"""
Main function to run the app which sets up the response model
and then creates a FastAPI app to serve the model.
Sets up the response model and then creates a
FastAPI app to serve the model.
The app listens on port 8000 and has two endpoints:
- /direct_message: for obtaining responses from direct messages
@@ -262,6 +268,9 @@ def create_index(
int, typer.Option(envvar="LLAMA_INDEX_NUM_OUTPUT")
] = DEFAULT_ARGS["num_output"],
) -> None:
"""
Create an index for the Reginald model.
"""
set_up_logging_config(level=20)
main(
cli="create_index",
@@ -288,6 +297,12 @@ def chat(
Optional[str],
typer.Option(envvar="REGINALD_MODEL_NAME", help=HELP_TEXT["model_name"]),
] = None,
streaming: Annotated[
bool,
typer.Option(
help=HELP_TEXT["streaming"],
),
] = True,
mode: Annotated[
str, typer.Option(envvar="LLAMA_INDEX_MODE", help=HELP_TEXT["mode"])
] = DEFAULT_ARGS["mode"],
@@ -339,9 +354,13 @@ def chat(
str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
] = DEFAULT_ARGS["device"],
) -> None:
"""
Run the chat interaction with the Reginald model.
"""
set_up_logging_config(level=40)
main(
cli="chat",
streaming=streaming,
model=model,
model_name=model_name,
mode=mode,
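
The new streaming option is a plain boolean Typer option that defaults to True, so Typer derives a paired command-line flag from it. A self-contained sketch of that behaviour (a hypothetical single-command CLI, not the full Reginald one):

```python
from typing import Annotated

import typer

cli = typer.Typer()


@cli.command()
def chat(
    streaming: Annotated[
        bool,
        typer.Option(help="Whether to use streaming for the chat interaction."),
    ] = True,
) -> None:
    # Typer exposes a boolean option as --streaming / --no-streaming,
    # with streaming on by default.
    typer.echo(f"streaming={streaming}")


if __name__ == "__main__":
    cli()
```

Invoked as `python sketch.py --no-streaming`, this prints streaming=False; with no flag, streaming stays on, matching the default chat behaviour introduced here.
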
4 changes: 2 additions & 2 deletions reginald/models/models/__init__.py
@@ -25,13 +25,13 @@
}

DEFAULTS = {
"chat-completion-azure": "reginald-curie",
"chat-completion-azure": "reginald-gpt4",
"chat-completion-openai": "gpt-3.5-turbo",
"hello": None,
"llama-index-ollama": "llama3",
"llama-index-llama-cpp": "https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF/resolve/main/llama-2-13b-chat.Q6_K.gguf",
"llama-index-hf": "microsoft/phi-1_5",
"llama-index-gpt-azure": "reginald-gpt35-turbo",
"llama-index-gpt-azure": "reginald-gpt4",
"llama-index-gpt-openai": "gpt-3.5-turbo",
}

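
Both Azure defaults move to the reginald-gpt4 deployment, replacing the retired reginald-curie and the older reginald-gpt35-turbo. A sketch of how such a mapping is typically consumed, under the assumption (the relevant code is not rendered in this excerpt) that an explicit model name overrides the per-model default:

```python
from typing import Optional

from reginald.models.models import DEFAULTS  # the mapping updated above


def resolve_model_name(model: str, model_name: Optional[str]) -> Optional[str]:
    # Assumed fallback logic: an explicit name wins; otherwise use the
    # default, e.g. "reginald-gpt4" for "llama-index-gpt-azure".
    return model_name if model_name is not None else DEFAULTS[model]
```
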
4 changes: 4 additions & 0 deletions reginald/models/models/base.py
@@ -28,9 +28,13 @@ def __init__(self, emoji: Optional[str], *args: Any, **kwargs: Any):
Emoji to use for the bot's response
"""
self.emoji = emoji
self.mode = "NA"

def direct_message(self, message: str, user_id: str) -> MessageResponse:
raise NotImplementedError

def channel_mention(self, message: str, user_id: str) -> MessageResponse:
raise NotImplementedError

def stream_message(self, message: str, user_id: str) -> None:
raise NotImplementedError
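
The base class gains a third hook, stream_message, alongside direct_message and channel_mention; streaming models write tokens as they arrive instead of returning a MessageResponse. A minimal hypothetical subclass (Echo is illustrative only, not part of the diff) showing the contract:

```python
from reginald.models.models.base import MessageResponse, ResponseModel


class Echo(ResponseModel):
    """Illustrative model: echoes the user's message back."""

    def direct_message(self, message: str, user_id: str) -> MessageResponse:
        return MessageResponse(f"<@{user_id}> said: {message}")

    def channel_mention(self, message: str, user_id: str) -> MessageResponse:
        return self.direct_message(message, user_id)

    def stream_message(self, message: str, user_id: str) -> None:
        # Streaming implementations print tokens incrementally rather than
        # returning a single response object.
        for token in message.split():
            print(token, end=" ", flush=True)
```
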
42 changes: 40 additions & 2 deletions reginald/models/models/chat_completion.py
@@ -1,13 +1,12 @@
import logging
import os
import sys
from typing import Any

import openai
from openai import AzureOpenAI, OpenAI

from reginald.models.models.base import MessageResponse, ResponseModel
from reginald.utils import get_env_var
from reginald.utils import get_env_var, stream_iter_progress_wrapper


class ChatCompletionBase(ResponseModel):
@@ -155,6 +154,35 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
"""
return self._respond(message=message, user_id=user_id)

def stream_message(self, message: str, user_id: str) -> None:
if self.mode == "chat":
response = self.client.chat.completions.create(
model=self.engine,
messages=[{"role": "user", "content": message}],
frequency_penalty=self.frequency_penalty,
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
stop=None,
temperature=self.temperature,
top_p=self.top_p,
stream=True,
)
elif self.mode == "query":
response = self.client.completions.create(
model=self.engine,
frequency_penalty=self.frequency_penalty,
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
prompt=message,
stop=None,
temperature=self.temperature,
top_p=self.top_p,
stream=True,
)

for chunk in stream_iter_progress_wrapper(response):
print(chunk.choices[0].delta.content, end="", flush=True)


class ChatCompletionOpenAI(ChatCompletionBase):
def __init__(
@@ -233,3 +261,13 @@ def channel_mention(self, message: str, user_id: str) -> MessageResponse:
Response from the query engine.
"""
return self._respond(message=message, user_id=user_id)

def stream_message(self, message: str, user_id: str) -> None:
response = self.client.chat.completions.create(
model=self.model_name,
messages=[{"role": "user", "content": message}],
stream=True,
)

for chunk in stream_iter_progress_wrapper(response):
print(chunk.choices[0].delta.content, end="", flush=True)
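
Both stream_message implementations print chunk.choices[0].delta.content as chunks arrive. In the v1 OpenAI client, the final chunk of a stream carries delta.content = None, so a standalone consumer would normally guard for that. A minimal sketch outside Reginald (the model name is an example, and the API key is assumed to be set in the environment):

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)
for chunk in stream:
    content = chunk.choices[0].delta.content
    if content is not None:  # the final chunk's delta has no content
        print(content, end="", flush=True)
print()
```
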
7 changes: 7 additions & 0 deletions reginald/models/models/hello.py
@@ -1,6 +1,7 @@
from typing import Any

from reginald.models.models.base import MessageResponse, ResponseModel
from reginald.utils import stream_iter_progress_wrapper


class Hello(ResponseModel):
@@ -16,3 +17,9 @@ def direct_message(self, message: str, user_id: str) -> MessageResponse:

def channel_mention(self, message: str, user_id: str) -> MessageResponse:
return MessageResponse(f"Hello <@{user_id}>")

def stream_message(self, message: str, user_id: str) -> None:
# print("\nReginald: ", end="")
token_list: tuple[str, ...] = ("Hello", "!", " How", " are", " you", "?")
for token in stream_iter_progress_wrapper(token_list):
print(token, end="", flush=True)
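
stream_iter_progress_wrapper lives in reginald/utils.py, one of the changed files not rendered above; together with the new rich dependency in pyproject.toml, it plausibly shows a spinner while waiting for the first token. A sketch under that assumption, not the actual implementation:

```python
from typing import Iterable, Iterator, TypeVar

from rich.progress import Progress, SpinnerColumn, TextColumn

T = TypeVar("T")


def stream_iter_progress_wrapper(iterable: Iterable[T]) -> Iterator[T]:
    """Assumed behaviour: spin until the first item arrives, then stream."""
    with Progress(
        SpinnerColumn(), TextColumn("{task.description}"), transient=True
    ) as progress:
        progress.add_task("Thinking...", total=None)
        iterator = iter(iterable)
        try:
            first = next(iterator)  # spinner shows while we block here
        except StopIteration:
            return
    # The transient progress display is cleared before output begins.
    yield first
    yield from iterator
```
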
(The remaining 5 changed files are not rendered here.)
