diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 978ce338..edc5e16a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -51,6 +51,7 @@ jobs:
       LAMINI_API_KEY: ${{ secrets.LAMINI_API_KEY }}
       GOOGLE_API_KEY : ${{ secrets.GOOGLE_API_KEY }}
       PERPLEXITYAI_API_KEY: ${{ secrets.PERPLEXITYAI_API_KEY }}
+      CEREBRAS_API_KEY: ${{ secrets.CEREBRAS_API_KEY }}
     steps:
       - uses: actions/checkout@v4
       - name: Install poetry
diff --git a/log10/_httpx_utils.py b/log10/_httpx_utils.py
index 5fa05620..2068540c 100644
--- a/log10/_httpx_utils.py
+++ b/log10/_httpx_utils.py
@@ -462,7 +462,8 @@ def patch_streaming_log(self, duration: int, full_response: str):
             "\r\n\r\n" if self.llm_client == LLM_CLIENTS.OPENAI and "perplexity" in self.host_header else "\n\n"
         )
         responses = full_response.split(separator)
-        response_json = self.parse_response_data(responses)
+        filter_responses = [r for r in responses if r]
+        response_json = self.parse_response_data(filter_responses)
 
         self.log_row["response"] = json.dumps(response_json)
         self.log_row["status"] = "finished"
@@ -507,10 +508,7 @@ def is_response_end_reached(self, text: str) -> bool:
         if self.llm_client == LLM_CLIENTS.ANTHROPIC:
             return self.is_anthropic_response_end_reached(text)
         elif self.llm_client == LLM_CLIENTS.OPENAI:
-            if "perplexity" in self.host_header:
-                return self.is_perplexity_response_end_reached(text)
-            else:
-                return self.is_openai_response_end_reached(text)
+            return self.is_openai_response_end_reached(text)
         else:
             logger.debug("Currently logging is only available for async openai and anthropic.")
             return False
@@ -518,19 +516,44 @@ def is_response_end_reached(self, text: str) -> bool:
     def is_anthropic_response_end_reached(self, text: str):
         return "event: message_stop" in text
 
-    def is_perplexity_response_end_reached(self, text: str):
+    def has_response_finished_with_stop_reason(self, text: str, parse_single_data_entry: bool = False):
         json_strings = text.split("data: ")[1:]
         # Parse the last JSON string
         last_json_str = json_strings[-1].strip()
-        last_object = json.loads(last_json_str)
-        return last_object.get("choices", [{}])[0].get("finish_reason", "") == "stop"
+        try:
+            last_object = json.loads(last_json_str)
+        except json.JSONDecodeError:
+            logger.debug(f"Full response: {repr(text)}")
+            logger.debug(f"Failed to parse the last JSON string: {last_json_str}")
+            return False
+
+        if choices := last_object.get("choices", []):
+            choice = choices[0]
+        else:
+            return False
+
+        finish_reason = choice.get("finish_reason", "")
+        content = choice.get("delta", {}).get("content", "")
+
+        if finish_reason == "stop":
+            return not content if parse_single_data_entry else True
+        return False
 
-    def is_openai_response_end_reached(self, text: str):
+    def is_openai_response_end_reached(self, text: str, parse_single_data_entry: bool = False):
         """
-        In Perplexity, the last item in the responses is empty.
-        In OpenAI and Mistral, the last item in the responses is "data: [DONE]".
+        For OpenAI and Mistral, the response end is reached when the data ends with "data: [DONE]\n\n".
+        For Perplexity and Cerebras, the response end is reached when the last JSON object has finish_reason == "stop".
+        The parse_single_data_entry argument distinguishes a single data entry from the full accumulated stream.
+        The function is called in two contexts: while streaming, to check whether the entire accumulated response has completed, and during individual response handling, to check whether a single response object has finished.
""" - return not text or "data: [DONE]" in text + hosts = ["openai", "mistral"] + + if any(p in self.host_header for p in hosts): + suffix = "data: [DONE]" + ("" if parse_single_data_entry else "\n\n") + if text.endswith(suffix): + return True + + return self.has_response_finished_with_stop_reason(text, parse_single_data_entry) def parse_anthropic_responses(self, responses: list[str]): message_id = "" @@ -628,7 +651,7 @@ def parse_openai_responses(self, responses: list[str]): finish_reason = "" for r in responses: - if self.is_openai_response_end_reached(r): + if self.is_openai_response_end_reached(r, parse_single_data_entry=True): break # loading the substring of response text after 'data: '. diff --git a/tests/pytest.ini b/tests/pytest.ini index d614c2de..06e3a7a2 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = - --openai_model=gpt-3.5-turbo - --openai_vision_model=gpt-4o + --openai_model=gpt-4o-mini + --openai_vision_model=gpt-4o-mini --anthropic_model=claude-3-haiku-20240307 --anthropic_legacy_model=claude-2.1 --google_model=gemini-1.5-pro-latest diff --git a/tests/test_openai_compatibility.py b/tests/test_openai_compatibility.py index 447db81d..ac08c099 100644 --- a/tests/test_openai_compatibility.py +++ b/tests/test_openai_compatibility.py @@ -11,19 +11,39 @@ log10(openai) -model_name = "llama-3.1-sonar-small-128k-chat" -if "PERPLEXITYAI_API_KEY" not in os.environ: - raise ValueError("Please set the PERPLEXITYAI_API_KEY environment variable.") - -compatibility_config = { - "base_url": "https://api.perplexity.ai", - "api_key": os.environ.get("PERPLEXITYAI_API_KEY"), -} +# Define a fixture that provides parameterized api_key and base_url +@pytest.fixture( + params=[ + { + "model_name": "llama-3.1-sonar-small-128k-chat", + "api_key": "PERPLEXITYAI_API_KEY", + "base_url": "https://api.perplexity.ai", + }, + {"model_name": "open-mistral-nemo", "api_key": "MISTRAL_API_KEY", "base_url": "https://api.mistral.ai/v1"}, + {"model_name": "llama3.1-8b", "api_key": "CEREBRAS_API_KEY", "base_url": "https://api.cerebras.ai/v1"}, + ] +) +def config(request): + api_environment_variable = request.param["api_key"] + if api_environment_variable not in os.environ: + raise ValueError(f"Please set the {api_environment_variable} environment variable.") + + return { + "base_url": request.param["base_url"], + "api_key": request.param["api_key"], + "model_name": request.param["model_name"], + } @pytest.mark.chat -def test_chat(session): +def test_chat(session, config): + compatibility_config = { + "base_url": config["base_url"], + "api_key": os.environ.get(config["api_key"]), + } + model_name = config["model_name"] + client = openai.OpenAI(**compatibility_config) completion = client.chat.completions.create( model=model_name, @@ -46,7 +66,13 @@ def test_chat(session): @pytest.mark.chat -def test_chat_not_given(session): +def test_chat_not_given(session, config): + compatibility_config = { + "base_url": config["base_url"], + "api_key": os.environ.get(config["api_key"]), + } + model_name = config["model_name"] + client = openai.OpenAI(**compatibility_config) completion = client.chat.completions.create( model=model_name, @@ -69,23 +95,13 @@ def test_chat_not_given(session): @pytest.mark.chat @pytest.mark.async_client @pytest.mark.asyncio(scope="module") -async def test_chat_async(session): - client = AsyncOpenAI(**compatibility_config) - completion = await client.chat.completions.create( - model=model_name, - messages=[{"role": "user", "content": "Say this is a test"}], - ) 
+async def test_chat_async(session, config):
+    compatibility_config = {
+        "base_url": config["base_url"],
+        "api_key": os.environ.get(config["api_key"]),
+    }
+    model_name = config["model_name"]
 
-    content = completion.choices[0].message.content
-    assert isinstance(content, str)
-    await finalize()
-    _LogAssertion(completion_id=session.last_completion_id(), message_content=content).assert_chat_response()
-
-
-@pytest.mark.chat
-@pytest.mark.async_client
-@pytest.mark.asyncio(scope="module")
-async def test_perplexity_chat_async(session):
     client = AsyncOpenAI(**compatibility_config)
     completion = await client.chat.completions.create(
         model=model_name,
@@ -100,7 +116,13 @@ async def test_perplexity_chat_async(session):
 
 @pytest.mark.chat
 @pytest.mark.stream
-def test_chat_stream(session):
+def test_chat_stream(session, config):
+    compatibility_config = {
+        "base_url": config["base_url"],
+        "api_key": os.environ.get(config["api_key"]),
+    }
+    model_name = config["model_name"]
+
     client = openai.OpenAI(**compatibility_config)
     response = client.chat.completions.create(
         model=model_name,
@@ -111,7 +133,7 @@ def test_chat_stream(session):
 
     output = ""
     for chunk in response:
-        output += chunk.choices[0].delta.content
+        output += chunk.choices[0].delta.content or ""
 
     _LogAssertion(completion_id=session.last_completion_id(), message_content=output).assert_chat_response()
 
@@ -119,9 +141,14 @@
 @pytest.mark.async_client
 @pytest.mark.stream
 @pytest.mark.asyncio(scope="module")
-async def test_chat_async_stream(session):
-    client = AsyncOpenAI(**compatibility_config)
+async def test_chat_async_stream(session, config):
+    compatibility_config = {
+        "base_url": config["base_url"],
+        "api_key": os.environ.get(config["api_key"]),
+    }
+    model_name = config["model_name"]
+    client = AsyncOpenAI(**compatibility_config)
 
     output = ""
     stream = await client.chat.completions.create(
         model=model_name,
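As an illustration of the two end-of-stream rules this patch introduces (the "data: [DONE]" sentinel for OpenAI/Mistral and the finish_reason == "stop" check for Perplexity/Cerebras), here is a minimal standalone sketch; detect_stream_end and the sample payloads below are hypothetical and for illustration only, not part of the log10 API.

```python
import json


def detect_stream_end(text: str, host: str, single_entry: bool = False) -> bool:
    # OpenAI/Mistral streams terminate with a literal "data: [DONE]" sentinel.
    if host in ("openai", "mistral"):
        suffix = "data: [DONE]" + ("" if single_entry else "\n\n")
        if text.endswith(suffix):
            return True

    # Perplexity/Cerebras streams terminate when the last SSE data entry
    # carries finish_reason == "stop".
    json_strings = text.split("data: ")[1:]
    if not json_strings:
        return False
    try:
        last_object = json.loads(json_strings[-1].strip())
    except json.JSONDecodeError:
        return False

    choices = last_object.get("choices", [])
    if not choices:
        return False
    finish_reason = choices[0].get("finish_reason", "")
    content = choices[0].get("delta", {}).get("content", "")
    if finish_reason == "stop":
        # When checking a single entry, also require an empty delta so a chunk
        # that still carries text is not treated as the terminal one.
        return not content if single_entry else True
    return False


# A Cerebras-style final chunk: finish_reason == "stop" with an empty delta.
final_chunk = 'data: {"choices": [{"delta": {"content": ""}, "finish_reason": "stop"}]}'
assert detect_stream_end(final_chunk, host="cerebras", single_entry=True)

# An OpenAI-style accumulated stream: ends with the [DONE] sentinel.
assert detect_stream_end('data: {"choices": []}\n\ndata: [DONE]\n\n', host="openai")
```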