Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async streaming support for OpenAI compatible models #281

Merged
merged 13 commits into from
Aug 23, 2024
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ jobs:
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
LAMINI_API_KEY: ${{ secrets.LAMINI_API_KEY }}
GOOGLE_API_KEY : ${{ secrets.GOOGLE_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}
steps:
- uses: actions/checkout@v4
- name: Install poetry
Expand Down Expand Up @@ -118,6 +119,7 @@ jobs:
echo "All variables are empty"
poetry run pytest -vv tests/ --ignore=tests/test_cli.py
poetry run pytest --llm_provider=anthropic -vv tests/test_magentic.py
poetry run pytest --llm_provider=litellm --openai_compatibility_model=perplexity/llama-3.1-sonar-small-128k-chat -vv tests/test_magentic.py -m chat
fi

- name: Run scheduled llm tests
Expand All @@ -126,3 +128,4 @@ jobs:
echo "This is a schedule event"
poetry run pytest -vv tests/ --ignore=tests/test_cli.py
poetry run pytest --openai_model=gpt-4o -m chat -vv tests/test_openai.py
poetry run pytest --llm_provider=litellm --openai_compatibility_model=perplexity/llama-3.1-sonar-small-128k-chat -vv tests/test_magentic.py -m chat
25 changes: 25 additions & 0 deletions examples/logging/magentic_async_chat_perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio

import openai
from magentic import UserMessage, chatprompt
from magentic.chat_model.litellm_chat_model import LitellmChatModel

from log10._httpx_utils import finalize
from log10.load import log10


log10(openai)


async def main(topic: str) -> str:
@chatprompt(
UserMessage(f"Tell me a joke about {topic}"),
model=LitellmChatModel(model="perplexity/llama-3.1-sonar-small-128k-chat"),
)
async def tell_joke(topic: str) -> str: ...

print(await tell_joke(topic))
await finalize()


asyncio.run(main("cats"))
25 changes: 25 additions & 0 deletions examples/logging/perplexity_async_chat_openai_compatibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
import os

import openai
from openai import AsyncOpenAI

from log10._httpx_utils import finalize
from log10.load import log10


log10(openai)

client = AsyncOpenAI(base_url="https://api.perplexity.ai", api_key=os.environ.get("PERPLEXITYAI_API_KEY"))


async def main():
completion = await client.chat.completions.create(
model="llama-3.1-sonar-small-128k-chat",
messages=[{"role": "user", "content": "Say this is a test"}],
)
print(completion.choices[0].message.content)
await finalize()


asyncio.run(main())
108 changes: 85 additions & 23 deletions log10/_httpx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import traceback
import uuid
from datetime import datetime, timezone
from enum import Enum

import httpx
from httpx import Request, Response
Expand All @@ -29,6 +30,27 @@
httpx_async_client = httpx.AsyncClient(timeout=timeout)


class LLM_CLIENTS(Enum):
ANTHROPIC = "Anthropic"
OPENAI = "OpenAI"
UNKNOWN = "Unknown"


CLIENT_PATHS = {
LLM_CLIENTS.ANTHROPIC: ["/v1/messages", "/v1/complete"],
# OpenAI and Mistral use the path "v1/chat/completions"
# Perplexity uses the path "chat/completions". Documentation: https://docs.perplexity.ai/reference/post_chat_completions
LLM_CLIENTS.OPENAI: ["v1/chat/completions", "chat/completions"],
}

USER_AGENT_NAME_TO_PROVIDER = {
"AsyncOpenAI": LLM_CLIENTS.OPENAI,
"AsyncAnthropic": LLM_CLIENTS.ANTHROPIC,
"Anthropic": LLM_CLIENTS.ANTHROPIC,
"OpenAI": LLM_CLIENTS.OPENAI,
}


def _get_time_diff(created_at):
time = datetime.fromisoformat(created_at)
now = datetime.now(timezone.utc)
Expand Down Expand Up @@ -225,14 +247,33 @@ def format_anthropic_request(request_content) -> str:
return json.dumps(request_content)


def _get_llm_client(request: Request) -> LLM_CLIENTS:
"""
The request object includes the user-agent header, which is used to identify the LLM client.
For example:
- headers({'user-agent': 'AsyncOpenAI/Python 1.40.6'})
- headers({'user-agent': 'Anthropic/Python 0.34.0'})
"""
user_agent = request.headers.get("user-agent", "")
class_name = user_agent.split("/")[0]

if class_name in ["AsyncAnthropic", "Anthropic"]:
return LLM_CLIENTS.ANTHROPIC
elif class_name in ["AsyncOpenAI", "OpenAI"]:
return LLM_CLIENTS.OPENAI
else:
return LLM_CLIENTS.UNKNOWN
kxtran marked this conversation as resolved.
Show resolved Hide resolved


def _init_log_row(request: Request):
start_time = time.time()
request.started = start_time
orig_module = ""
orig_qualname = ""
request_content_decode = request.content.decode("utf-8")
host = request.headers.get("host")
if "openai" in host:
llm_client = _get_llm_client(request)

if llm_client == LLM_CLIENTS.OPENAI:
if "chat" in str(request.url):
kind = "chat"
orig_module = "openai.api_resources.chat_completion"
Expand All @@ -241,7 +282,7 @@ def _init_log_row(request: Request):
kind = "completion"
orig_module = "openai.api_resources.completion"
orig_qualname = "Completion.create"
elif "anthropic" in host:
elif llm_client == LLM_CLIENTS.ANTHROPIC:
kind = "chat"
url_path = request.url
content_type = request.headers.get("content-type")
Expand All @@ -259,10 +300,10 @@ def _init_log_row(request: Request):
orig_qualname = "Completions.create"

request_content_decode = format_anthropic_request(request_content)

else:
logger.debug("Currently logging is only available for async openai and anthropic.")
return

log_row = {
"status": "started",
"kind": kind,
Expand All @@ -278,15 +319,14 @@ def _init_log_row(request: Request):


def get_completion_id(request: Request):
host = request.headers.get("host")
if "anthropic" in host:
paths = ["/v1/messages", "/v1/complete"]
if not any(path in str(request.url) for path in paths):
logger.debug("Currently logging is only available for anthropic v1/messages and v1/complete.")
return
llm_client = _get_llm_client(request)
if llm_client is LLM_CLIENTS.UNKNOWN:
logger.debug("Currently logging is only available for async openai and anthropic.")
return

if "openai" in host and "v1/chat/completions" not in str(request.url):
logger.debug("Currently logging is only available for openai v1/chat/completions.")
# Check if the request URL matches any of the allowed paths for the class name
if not any(path in str(request.url) for path in CLIENT_PATHS.get(llm_client, [])):
logger.debug(f"Currently logging is only available for {llm_client} {', '.join(CLIENT_PATHS[llm_client])}.")
return

completion_id = str(uuid.uuid4())
Expand Down Expand Up @@ -361,6 +401,10 @@ def log_request(self, request: httpx.Request):

logger.debug("LOG10: sending sync request")
self.log_row = _init_log_row(request)
if not self.log_row:
logger.debug("LOG10: log row is not initialized. Skipping")
return

_try_post_request(url=f"{base_url}/api/completions/{completion_id}", payload=self.log_row)


Expand Down Expand Up @@ -388,6 +432,10 @@ async def log_request(self, request: httpx.Request):

logger.debug("LOG10: sending async request")
self.log_row = _init_log_row(request)
if not self.log_row:
logger.debug("LOG10: log row is not initialized. Skipping")
return

asyncio.create_task(
_try_post_request_async(url=f"{base_url}/api/completions/{completion_id}", payload=self.log_row)
)
Expand All @@ -396,6 +444,8 @@ async def log_request(self, request: httpx.Request):
class _LogResponse(Response):
def __init__(self, *args, **kwargs):
self.log_row = kwargs.pop("log_row")
self.llm_client = _get_llm_client(kwargs.get("request"))
self.host_header = kwargs.get("request").headers.get("host")
super().__init__(*args, **kwargs)

def patch_streaming_log(self, duration: int, full_response: str):
Expand All @@ -410,7 +460,10 @@ def patch_streaming_log(self, duration: int, full_response: str):
for frame in current_stack_frame
]

responses = full_response.split("\n\n")
separator = (
"\r\n\r\n" if self.llm_client == LLM_CLIENTS.OPENAI and "perplexity" in self.host_header else "\n\n"
)
responses = full_response.split(separator)
response_json = self.parse_response_data(responses)

self.log_row["response"] = json.dumps(response_json)
Expand Down Expand Up @@ -452,21 +505,31 @@ async def aiter_bytes(self, *args, **kwargs):
)
yield chunk

def is_response_end_reached(self, text: str):
host = self.request.headers.get("host")
if "anthropic" in host:
def is_response_end_reached(self, text: str) -> bool:
if self.llm_client == LLM_CLIENTS.ANTHROPIC:
return self.is_anthropic_response_end_reached(text)
elif "openai" in host:
return self.is_openai_response_end_reached(text)
elif self.llm_client == LLM_CLIENTS.OPENAI:
if "perplexity" in self.host_header:
return self.is_perplexity_response_end_reached(text)
else:
return self.is_openai_response_end_reached(text)
else:
logger.debug("Currently logging is only available for async openai and anthropic.")
return False

def is_anthropic_response_end_reached(self, text: str):
return "event: message_stop" in text

def is_perplexity_response_end_reached(self, text: str):
json_strings = text.split("data: ")[1:]
# Parse the last JSON string
last_json_str = json_strings[-1].strip()
last_object = json.loads(last_json_str)
return last_object.get("choices", [{}])[0].get("finish_reason", "") == "stop"

def is_openai_response_end_reached(self, text: str):
return "data: [DONE]" in text
# For perplexity, the last item in the responses is empty
return "data: [DONE]" in text or not text
kxtran marked this conversation as resolved.
Show resolved Hide resolved

def parse_anthropic_responses(self, responses: list[str]):
message_id = ""
Expand Down Expand Up @@ -631,11 +694,10 @@ def parse_openai_responses(self, responses: list[str]):
return response_json

def parse_response_data(self, responses: list[str]):
host = self.request.headers.get("host")
if "openai" in host:
return self.parse_openai_responses(responses)
elif "anthropic" in host:
if self.llm_client == LLM_CLIENTS.ANTHROPIC:
return self.parse_anthropic_responses(responses)
elif self.llm_client == LLM_CLIENTS.OPENAI:
return self.parse_openai_responses(responses)
else:
logger.debug("Currently logging is only available for async openai and anthropic.")
return None
Expand Down
Loading
Loading