fix: OpenAIChatGenerator and OpenAIGenerator crashing when streaming with usage tracking (#8558)

* Fix OpenAIGenerator crashing when tracking usage with streaming enabled

* Fix OpenAIChatGenerator crashing when tracking usage with streaming enabled

* Add release notes

* Fix linting
silvanocerza authored Nov 20, 2024
1 parent 3a30ee3 commit 3ef8c08
Showing 5 changed files with 113 additions and 29 deletions.
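
Context for the fix: when `stream_options` includes `{"include_usage": True}`, the OpenAI API terminates the stream with one extra chunk whose `choices` list is empty and whose `usage` field is populated. The pre-fix code took the last chunk seen by the loop and read `chunk.choices[0].finish_reason` from it, which raised an IndexError on that final usage-only chunk. A minimal repro sketch against the OpenAI Python SDK (illustration only, not Haystack code; the model and prompt are placeholders):

# Repro sketch of the failure mode, using the OpenAI SDK directly.
# Assumes OPENAI_API_KEY is exported; model and prompt are placeholders.
from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What's the capital of France?"}],
    stream=True,
    stream_options={"include_usage": True},
)

last_chunk = None
for last_chunk in stream:
    pass

# The final chunk carries usage data but an empty `choices` list, so the
# pre-fix equivalent of the last line below raised IndexError:
print(last_chunk.usage)       # populated usage object
print(last_chunk.choices[0])  # IndexError: list index out of range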
38 changes: 21 additions & 17 deletions haystack/components/generators/chat/openai.py
@@ -222,19 +222,19 @@ def run(
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
-            chunk = None
+            completion_chunk = None
             _first_token = True
 
             # pylint: disable=not-an-iterable
-            for chunk in chat_completion:
-                if chunk.choices and streaming_callback:
-                    chunk_delta: StreamingChunk = self._build_chunk(chunk)
+            for completion_chunk in chat_completion:
+                if completion_chunk.choices and streaming_callback:
+                    chunk_delta: StreamingChunk = self._build_chunk(completion_chunk)
                     if _first_token:
                         _first_token = False
                         chunk_delta.meta["completion_start_time"] = datetime.now().isoformat()
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
-            completions = [self._connect_chunks(chunk, chunks)]
+            completions = [self._create_message_from_chunks(completion_chunk, chunks)]
         # if streaming is disabled, the completion is a ChatCompletion
         elif isinstance(chat_completion, ChatCompletion):
             completions = [self._build_message(chat_completion, choice) for choice in chat_completion.choices]
@@ -245,18 +245,20 @@ def run(

return {"replies": completions}

def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
def _create_message_from_chunks(
self, completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
) -> ChatMessage:
"""
Connects the streaming chunks into a single ChatMessage.
Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
:param chunk: The last chunk returned by the OpenAI API.
:param chunks: The list of all chunks returned by the OpenAI API.
:param completion_chunk: The last completion chunk returned by the OpenAI API.
:param streamed_chunks: The list of all chunks returned by the OpenAI API.
"""
is_tools_call = bool(chunks[0].meta.get("tool_calls"))
is_function_call = bool(chunks[0].meta.get("function_call"))
is_tools_call = bool(streamed_chunks[0].meta.get("tool_calls"))
is_function_call = bool(streamed_chunks[0].meta.get("function_call"))
# if it's a tool call or function call, we need to build the payload dict from all the chunks
if is_tools_call or is_function_call:
tools_len = 1 if is_function_call else len(chunks[0].meta.get("tool_calls", []))
tools_len = 1 if is_function_call else len(streamed_chunks[0].meta.get("tool_calls", []))
# don't change this approach of building payload dicts, otherwise mypy will complain
p_def: Dict[str, Any] = {
"index": 0,
@@ -265,7 +267,7 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
"type": "function",
}
payloads = [copy.deepcopy(p_def) for _ in range(tools_len)]
for chunk_payload in chunks:
for chunk_payload in streamed_chunks:
if is_tools_call:
deltas = chunk_payload.meta.get("tool_calls") or []
else:
@@ -287,16 +289,18 @@ def _connect_chunks(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
         else:
             total_content = ""
             total_meta = {}
-            for streaming_chunk in chunks:
+            for streaming_chunk in streamed_chunks:
                 total_content += streaming_chunk.content
                 total_meta.update(streaming_chunk.meta)
             complete_response = ChatMessage.from_assistant(total_content, meta=total_meta)
+        finish_reason = streamed_chunks[-1].meta["finish_reason"]
         complete_response.meta.update(
             {
-                "model": chunk.model,
+                "model": completion_chunk.model,
                 "index": 0,
-                "finish_reason": chunk.choices[0].finish_reason,
-                "usage": {},  # we don't have usage data for streaming responses
+                "finish_reason": finish_reason,
+                # Usage is available when streaming only if the user explicitly requests it
+                "usage": dict(completion_chunk.usage or {}),
             }
         )
         return complete_response
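
A note on the new `usage` entry above: the SDK returns usage as a pydantic model, and iterating a pydantic model yields `(field, value)` pairs, so `dict()` flattens it into plain metadata; the `or {}` guard keeps the entry empty when usage was not requested. A small sketch, assuming the OpenAI SDK's `CompletionUsage` type:

# Illustration only, assuming the OpenAI SDK's CompletionUsage pydantic model.
from openai.types import CompletionUsage

usage = CompletionUsage(prompt_tokens=12, completion_tokens=7, total_tokens=19)

# Iterating a pydantic model yields (field, value) pairs, so dict() flattens it.
print(dict(usage))
# e.g. {'completion_tokens': 7, 'prompt_tokens': 12, 'total_tokens': 19}

# Without include_usage, the final chunk's usage is None and the guard applies:
print(dict(None or {}))  # {}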
30 changes: 18 additions & 12 deletions haystack/components/generators/openai.py
@@ -49,7 +49,7 @@ class OpenAIGenerator:
     ```
     """
 
-    def __init__(
+    def __init__(  # pylint: disable=too-many-positional-arguments
         self,
         api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
         model: str = "gpt-4o-mini",
@@ -222,15 +222,17 @@ def run(
             if num_responses > 1:
                 raise ValueError("Cannot stream multiple responses, please set n=1.")
             chunks: List[StreamingChunk] = []
-            chunk = None
+            completion_chunk: Optional[ChatCompletionChunk] = None
 
             # pylint: disable=not-an-iterable
-            for chunk in completion:
-                if chunk.choices and streaming_callback:
-                    chunk_delta: StreamingChunk = self._build_chunk(chunk)
+            for completion_chunk in completion:
+                if completion_chunk.choices and streaming_callback:
+                    chunk_delta: StreamingChunk = self._build_chunk(completion_chunk)
                     chunks.append(chunk_delta)
                     streaming_callback(chunk_delta)  # invoke callback with the chunk_delta
-            completions = [self._connect_chunks(chunk, chunks)]
+            # Makes type checkers happy
+            assert completion_chunk is not None
+            completions = [self._create_message_from_chunks(completion_chunk, chunks)]
         elif isinstance(completion, ChatCompletion):
             completions = [self._build_message(completion, choice) for choice in completion.choices]
 
@@ -244,17 +246,21 @@
         }
 
     @staticmethod
-    def _connect_chunks(chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
+    def _create_message_from_chunks(
+        completion_chunk: ChatCompletionChunk, streamed_chunks: List[StreamingChunk]
+    ) -> ChatMessage:
         """
-        Connects the streaming chunks into a single ChatMessage.
+        Creates a single ChatMessage from the streamed chunks. Some data is retrieved from the completion chunk.
         """
-        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in chunks]))
+        complete_response = ChatMessage.from_assistant("".join([chunk.content for chunk in streamed_chunks]))
+        finish_reason = streamed_chunks[-1].meta["finish_reason"]
         complete_response.meta.update(
             {
-                "model": chunk.model,
+                "model": completion_chunk.model,
                 "index": 0,
-                "finish_reason": chunk.choices[0].finish_reason,
-                "usage": {},  # we don't have usage data for streaming responses
+                "finish_reason": finish_reason,
+                # Usage is available when streaming only if the user explicitly requests it
+                "usage": dict(completion_chunk.usage or {}),
             }
         )
         return complete_response
@@ -0,0 +1,4 @@
+---
+fixes:
+  - |
+    Fix `OpenAIChatGenerator` and `OpenAIGenerator` crashing when a `streaming_callback` is used and `generation_kwargs` contains `{"stream_options": {"include_usage": True}}`.
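
A minimal usage sketch of the configuration this note describes (mirroring the integration tests below; assumes `OPENAI_API_KEY` is exported):

from haystack.components.generators import OpenAIGenerator

generator = OpenAIGenerator(
    streaming_callback=lambda chunk: print(chunk.content, end="", flush=True),
    generation_kwargs={"stream_options": {"include_usage": True}},
)
result = generator.run("What's the capital of France?")
# With the fix, token usage is reported instead of crashing mid-stream:
print(result["meta"][0]["usage"])  # e.g. {'prompt_tokens': 14, 'completion_tokens': 7, 'total_tokens': 21}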
31 changes: 31 additions & 0 deletions test/components/generators/chat/test_openai.py
@@ -329,3 +329,34 @@ def __call__(self, chunk: StreamingChunk) -> None:

         assert callback.counter > 1
         assert "Paris" in callback.responses
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_streaming_with_include_usage(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.counter = 0
+
+            def __call__(self, chunk: StreamingChunk) -> None:
+                self.counter += 1
+                self.responses += chunk.content if chunk.content else ""
+
+        callback = Callback()
+        component = OpenAIChatGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
+        results = component.run([ChatMessage.from_user("What's the capital of France?")])
+
+        assert len(results["replies"]) == 1
+        message: ChatMessage = results["replies"][0]
+        assert "Paris" in message.content
+
+        assert "gpt-4o-mini" in message.meta["model"]
+        assert message.meta["finish_reason"] == "stop"
+
+        assert callback.counter > 1
+        assert "Paris" in callback.responses
39 changes: 39 additions & 0 deletions test/components/generators/test_openai.py
@@ -333,3 +333,42 @@ def test_run_with_system_prompt(self):
             system_prompt="You answer in German, regardless of the language on which a question is asked.",
         )
         assert "pythagoras" in result["replies"][0].lower()
+
+    @pytest.mark.skipif(
+        not os.environ.get("OPENAI_API_KEY", None),
+        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+    )
+    @pytest.mark.integration
+    def test_live_run_streaming_with_include_usage(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.counter = 0
+
+            def __call__(self, chunk: StreamingChunk) -> None:
+                self.counter += 1
+                self.responses += chunk.content if chunk.content else ""
+
+        callback = Callback()
+        component = OpenAIGenerator(
+            streaming_callback=callback, generation_kwargs={"stream_options": {"include_usage": True}}
+        )
+        results = component.run("What's the capital of France?")
+
+        assert len(results["replies"]) == 1
+        assert len(results["meta"]) == 1
+        response: str = results["replies"][0]
+        assert "Paris" in response
+
+        metadata = results["meta"][0]
+
+        assert "gpt-4o-mini" in metadata["model"]
+        assert metadata["finish_reason"] == "stop"
+
+        assert "usage" in metadata
+        assert "prompt_tokens" in metadata["usage"] and metadata["usage"]["prompt_tokens"] > 0
+        assert "completion_tokens" in metadata["usage"] and metadata["usage"]["completion_tokens"] > 0
+        assert "total_tokens" in metadata["usage"] and metadata["usage"]["total_tokens"] > 0
+
+        assert callback.counter > 1
+        assert "Paris" in callback.responses