
Generalize perplexity logic for streaming end detection across OpenAI compatible models #286

Merged: 11 commits, Aug 29, 2024
1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -51,6 +51,7 @@ jobs:
LAMINI_API_KEY: ${{ secrets.LAMINI_API_KEY }}
GOOGLE_API_KEY : ${{ secrets.GOOGLE_API_KEY }}
PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}
CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
steps:
- uses: actions/checkout@v4
- name: Install poetry
49 changes: 36 additions & 13 deletions log10/_httpx_utils.py
@@ -462,7 +462,8 @@ def patch_streaming_log(self, duration: int, full_response: str):
"\r\n\r\n" if self.llm_client == LLM_CLIENTS.OPENAI and "perplexity" in self.host_header else "\n\n"
)
responses = full_response.split(separator)
response_json = self.parse_response_data(responses)
filter_responses = [r for r in responses if r]
response_json = self.parse_response_data(filter_responses)

self.log_row["response"] = json.dumps(response_json)
self.log_row["status"] = "finished"
@@ -507,30 +508,52 @@ def is_response_end_reached(self, text: str) -> bool:
if self.llm_client == LLM_CLIENTS.ANTHROPIC:
return self.is_anthropic_response_end_reached(text)
elif self.llm_client == LLM_CLIENTS.OPENAI:
if "perplexity" in self.host_header:
return self.is_perplexity_response_end_reached(text)
else:
return self.is_openai_response_end_reached(text)
return self.is_openai_response_end_reached(text)
else:
logger.debug("Currently logging is only available for async openai and anthropic.")
return False

def is_anthropic_response_end_reached(self, text: str):
return "event: message_stop" in text

def is_perplexity_response_end_reached(self, text: str):
def has_response_finished_with_stop_reason(self, text: str, parse_single_data_entry: bool = False):
json_strings = text.split("data: ")[1:]
# Parse the last JSON string
last_json_str = json_strings[-1].strip()
last_object = json.loads(last_json_str)
return last_object.get("choices", [{}])[0].get("finish_reason", "") == "stop"
try:
last_object = json.loads(last_json_str)
except json.JSONDecodeError:
logger.debug(f"Full response: {repr(text)}")
logger.debug(f"Failed to parse the last JSON string: {last_json_str}")
return False

if choices := last_object.get("choices", []):
choice = choices[0]
else:
return False

finish_reason = choice.get("finish_reason", "")
content = choice.get("delta", {}).get("content", "")

if finish_reason == "stop":
return not content if parse_single_data_entry else True
return False

def is_openai_response_end_reached(self, text: str):
def is_openai_response_end_reached(self, text: str, parse_single_data_entry: bool = False):
"""
In Perplexity, the last item in the responses is empty.
In OpenAI and Mistral, the last item in the responses is "data: [DONE]".
For OpenAI and Mistral, the response end is reached when the data ends with "data: [DONE]\n\n".
For Perplexity and Cerebras, the response end is reached when the last JSON object has finish_reason == "stop".
The parse_single_data_entry argument distinguishes a single data entry from multiple accumulated entries.
The function is called in two contexts: first, to check whether the entire accumulated response has completed while processing streaming data, and second, to check whether a single response object has finished during individual response handling.
"""
return not text or "data: [DONE]" in text
hosts = ["openai", "mistral"]

if any(p in self.host_header for p in hosts):
suffix = "data: [DONE]" + ("" if parse_single_data_entry else "\n\n")
if text.endswith(suffix):
return True

return self.has_response_finished_with_stop_reason(text, parse_single_data_entry)

def parse_anthropic_responses(self, responses: list[str]):
message_id = ""
@@ -628,7 +651,7 @@ def parse_openai_responses(self, responses: list[str]):
finish_reason = ""

for r in responses:
if self.is_openai_response_end_reached(r):
if self.is_openai_response_end_reached(r, parse_single_data_entry=True):
break

# loading the substring of response text after 'data: '.
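To summarize the end-detection logic this file now uses, here is a minimal, self-contained sketch of the two paths described in the docstring of is_openai_response_end_reached. The helper names and sample chunks are illustrative assumptions, not code from this PR:

import json

def ends_with_done(text: str, parse_single_data_entry: bool = False) -> bool:
    # OpenAI / Mistral: the stream ends with a literal "data: [DONE]" sentinel.
    suffix = "data: [DONE]" + ("" if parse_single_data_entry else "\n\n")
    return text.endswith(suffix)

def finished_with_stop_reason(text: str) -> bool:
    # Perplexity / Cerebras: the last JSON object carries finish_reason == "stop".
    last_json = text.split("data: ")[-1].strip()
    try:
        obj = json.loads(last_json)
    except json.JSONDecodeError:
        return False
    choices = obj.get("choices") or [{}]
    return choices[0].get("finish_reason") == "stop"

# Assumed sample streams for illustration only.
openai_stream = 'data: {"choices": [{"delta": {}, "finish_reason": "stop"}]}\n\ndata: [DONE]\n\n'
cerebras_stream = 'data: {"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]}\n\n'

assert ends_with_done(openai_stream)
assert finished_with_stop_reason(cerebras_stream)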
4 changes: 2 additions & 2 deletions tests/pytest.ini
@@ -1,7 +1,7 @@
[pytest]
addopts =
--openai_model=gpt-3.5-turbo
--openai_vision_model=gpt-4o
--openai_model=gpt-4o-mini
--openai_vision_model=gpt-4o-mini
--anthropic_model=claude-3-haiku-20240307
--anthropic_legacy_model=claude-2.1
--google_model=gemini-1.5-pro-latest
87 changes: 57 additions & 30 deletions tests/test_openai_compatibility.py
@@ -11,19 +11,39 @@

log10(openai)

model_name = "llama-3.1-sonar-small-128k-chat"

if "PERPLEXITYAI_API_KEY" not in os.environ:
raise ValueError("Please set the PERPLEXITYAI_API_KEY environment variable.")

compatibility_config = {
"base_url": "https://api.perplexity.ai",
"api_key": os.environ.get("PERPLEXITYAI_API_KEY"),
}
# Define a fixture that provides a parameterized model_name, api_key, and base_url
@pytest.fixture(
params=[
{
"model_name": "llama-3.1-sonar-small-128k-chat",
"api_key": "PERPLEXITYAI_API_KEY",
"base_url": "https://api.perplexity.ai",
},
{"model_name": "open-mistral-nemo", "api_key": "MISTRAL_API_KEY", "base_url": "https://api.mistral.ai/v1"},
{"model_name": "llama3.1-8b", "api_key": "CEREBRAS_API_KEY", "base_url": "https://api.cerebras.ai/v1"},
]
)
def config(request):
api_environment_variable = request.param["api_key"]
if api_environment_variable not in os.environ:
raise ValueError(f"Please set the {api_environment_variable} environment variable.")

return {
"base_url": request.param["base_url"],
"api_key": request.param["api_key"],
Review comment from the PR author: I avoided using os.environ.get(request.param["api_key"]) in the fixture. When a test fails and the API key is loaded in the fixture, the pytest report will print out the API keys.

"model_name": request.param["model_name"],
}
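
As the author's comment above explains, the fixture deliberately returns the name of the environment variable rather than its value, so a failing test does not surface the secret in the pytest report. A hypothetical sketch of that pattern (identifiers are illustrative, not from this PR):

import os

# Only the variable name travels through the fixture parameters ...
provider = {"api_key": "PERPLEXITYAI_API_KEY", "base_url": "https://api.perplexity.ai"}

def build_client_kwargs(cfg: dict) -> dict:
    # ... and the secret is resolved only at the point of use, inside the test body.
    return {"base_url": cfg["base_url"], "api_key": os.environ.get(cfg["api_key"], "")}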


@pytest.mark.chat
def test_chat(session):
def test_chat(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
completion = client.chat.completions.create(
model=model_name,
@@ -46,7 +66,13 @@ def test_chat(session):


@pytest.mark.chat
def test_chat_not_given(session):
def test_chat_not_given(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
completion = client.chat.completions.create(
model=model_name,
@@ -69,23 +95,13 @@ def test_chat_not_given(session):
@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio(scope="module")
async def test_chat_async(session):
client = AsyncOpenAI(**compatibility_config)
completion = await client.chat.completions.create(
model=model_name,
messages=[{"role": "user", "content": "Say this is a test"}],
)
async def test_chat_async(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

content = completion.choices[0].message.content
assert isinstance(content, str)
await finalize()
_LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()


@pytest.mark.chat
@pytest.mark.async_client
@pytest.mark.asyncio(scope="module")
async def test_perplexity_chat_async(session):
client = AsyncOpenAI(**compatibility_config)
completion = await client.chat.completions.create(
model=model_name,
@@ -100,7 +116,13 @@ async def test_perplexity_chat_async(session):

@pytest.mark.chat
@pytest.mark.stream
def test_chat_stream(session):
def test_chat_stream(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = openai.OpenAI(**compatibility_config)
response = client.chat.completions.create(
model=model_name,
@@ -111,17 +133,22 @@ def test_chat_stream(session):

output = ""
for chunk in response:
output += chunk.choices[0].delta.content
output += chunk.choices[0].delta.content or ""

_LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()


@pytest.mark.async_client
@pytest.mark.stream
@pytest.mark.asyncio(scope="module")
async def test_chat_async_stream(session):
client = AsyncOpenAI(**compatibility_config)
async def test_chat_async_stream(session, config):
compatibility_config = {
"base_url": config["base_url"],
"api_key": os.environ.get(config["api_key"]),
}
model_name = config["model_name"]

client = AsyncOpenAI(**compatibility_config)
output = ""
stream = await client.chat.completions.create(
model=model_name,