Skip to content

Commit

Permalink
feat: support for tools in OpenAIChatGenerator (#8666)
Browse files Browse the repository at this point in the history
* move chatmsg>openai conversion to chatmsg dataclass

* implementation and tests cleanup

* release note

* try fixing azure chat generator

* add serde test for toolinvoker

* small fix
  • Loading branch information
anakin87 authored Dec 20, 2024
1 parent 7dcbf25 commit 188b2a7
Show file tree
Hide file tree
Showing 17 changed files with 720 additions and 305 deletions.
5 changes: 5 additions & 0 deletions haystack/components/generators/chat/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ def __init__( # pylint: disable=too-many-positional-arguments
self.max_retries = max_retries or int(os.environ.get("OPENAI_MAX_RETRIES", 5))
self.default_headers = default_headers or {}

# This ChatGenerator does not yet support tools. The following workaround ensures that we do not
# get an error when invoking the run method of the parent class (OpenAIChatGenerator).
self.tools = None
self.tools_strict = False

self.client = AzureOpenAI(
api_version=api_version,
azure_endpoint=azure_endpoint,
Expand Down
14 changes: 6 additions & 8 deletions haystack/components/generators/chat/hugging_face_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
msg = f"Unknown api_type {api_type}"
raise ValueError(msg)

if tools:
if streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
if tools and streaming_callback is not None:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

# handle generation kwargs setup
generation_kwargs = generation_kwargs.copy() if generation_kwargs else {}
Expand Down Expand Up @@ -241,10 +240,9 @@ def run(
formatted_messages = [convert_message_to_hf_format(message) for message in messages]

tools = tools or self.tools
if tools:
if self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)
if tools and self.streaming_callback:
raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.")
_check_duplicate_tool_names(tools)

if self.streaming_callback:
return self._run_streaming(formatted_messages, generation_kwargs)
Expand Down
350 changes: 209 additions & 141 deletions haystack/components/generators/chat/openai.py

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions haystack/components/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.openai_utils import _convert_message_to_openai_format
from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable

Expand Down Expand Up @@ -207,7 +206,7 @@ def run(
streaming_callback = streaming_callback or self.streaming_callback

# adapt ChatMessage(s) to the format expected by the OpenAI API
openai_formatted_messages = [_convert_message_to_openai_format(message) for message in messages]
openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
model=self.model,
Expand Down
23 changes: 0 additions & 23 deletions haystack/components/generators/openai_utils.py

This file was deleted.

45 changes: 45 additions & 0 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import json
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
Expand Down Expand Up @@ -381,3 +382,47 @@ def from_dict(cls, data: Dict[str, Any]) -> "ChatMessage":
data["_content"] = content

return cls(**data)

def to_openai_dict_format(self) -> Dict[str, Any]:
    """
    Convert a ChatMessage to the dictionary format expected by OpenAI's Chat API.

    :returns: A dictionary with a ``role`` key plus ``content``, ``tool_calls``,
        and/or ``tool_call_id`` keys depending on the message payload.
    :raises ValueError: If the message carries no content, carries more than one
        text/tool-result payload, or references a tool call without an ``id``.
    """
    texts = self.texts
    calls = self.tool_calls
    call_results = self.tool_call_results

    # A message must carry some payload, and at most one text or tool result.
    if not (texts or calls or call_results):
        raise ValueError(
            "A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`."
        )
    if len(texts) + len(call_results) > 1:
        raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.")

    message: Dict[str, Any] = {"role": self._role.value}

    # Tool results are exclusive: a tool message carries only the result payload.
    if call_results:
        first_result = call_results[0]
        if first_result.origin.id is None:
            raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
        message["content"] = first_result.result
        message["tool_call_id"] = first_result.origin.id
        # OpenAI does not provide a way to communicate errors in tool invocations, so we ignore the error field
        return message

    if texts:
        message["content"] = texts[0]
    if calls:
        # Validate every call id up front before serializing any of them.
        if any(call.id is None for call in calls):
            raise ValueError("`ToolCall` must have a non-null `id` attribute to be used with OpenAI.")
        # We disable ensure_ascii so special chars like emojis are not converted
        message["tool_calls"] = [
            {
                "id": call.id,
                "type": "function",
                "function": {"name": call.tool_name, "arguments": json.dumps(call.arguments, ensure_ascii=False)},
            }
            for call in calls
        ]
    return message
6 changes: 4 additions & 2 deletions haystack/dataclasses/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,15 @@ def _remove_title_from_schema(schema: Dict[str, Any]):
del property_schema[key]


def _check_duplicate_tool_names(tools: List[Tool]) -> None:
def _check_duplicate_tool_names(tools: Optional[List[Tool]]) -> None:
"""
Check for duplicate tool names and raises a ValueError if they are found.
Checks for duplicate tool names and raises a ValueError if they are found.
:param tools: The list of tools to check.
:raises ValueError: If duplicate tool names are found.
"""
if tools is None:
return
tool_names = [tool.name for tool in tools]
duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1}
if duplicate_tool_names:
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/openai-tools-26f58a981c4066ef.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Add support for Tools in the OpenAI Chat Generator.
14 changes: 0 additions & 14 deletions test/components/generators/chat/conftest.py

This file was deleted.

8 changes: 8 additions & 0 deletions test/components/generators/chat/test_hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ def streaming_callback_handler(x):
return x


@pytest.fixture
def chat_messages():
return [
ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"),
ChatMessage.from_user("Tell me about Berlin"),
]


@pytest.fixture
def model_info_mock():
with patch(
Expand Down
Loading

0 comments on commit 188b2a7

Please sign in to comment.