Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update OpenAPIServiceConnector to new ChatMessage #8817

Merged
merged 10 commits into from
Feb 10, 2025
180 changes: 141 additions & 39 deletions haystack/components/connectors/openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,135 @@
logger = logging.getLogger(__name__)

with LazyImport("Run 'pip install openapi3'") as openapi_imports:
import requests
from openapi3 import OpenAPI
from openapi3.errors import UnexpectedResponseError
from openapi3.paths import Operation

# Patch the request method to add support for the proper raw_response handling
# If you see that https://github.com/Dorthu/openapi3/pull/124/
# is merged, we can remove this patch - notify authors of this code
def patch_request(
self,
base_url: str,
*,
data: Optional[Any] = None,
parameters: Optional[Dict[str, Any]] = None,
raw_response: bool = False,
security: Optional[Dict[str, str]] = None,
session: Optional[Any] = None,
verify: Union[bool, str] = True,
) -> Optional[Any]:
"""
Sends an HTTP request as described by this path.

:param base_url: The URL to append this operation's path to when making
the call.
:param data: The request body to send.
:param parameters: The parameters used to create the path.
:param raw_response: If true, return the raw response instead of validating
and exterpolating it.
:param security: The security scheme to use, and the values it needs to
process successfully.
:param session: A persistent request session.
:param verify: If we should do an ssl verification on the request or not.
In case str was provided, will use that as the CA.
:return: The response data, either raw or processed depending on raw_response flag.
"""
# Set request method (e.g. 'GET')
self._request = requests.Request(self.path[-1])

# Set self._request.url to base_url w/ path
self._request.url = base_url + self.path[-2]

parameters = parameters or {}
security = security or {}

if security and self.security:
security_requirement = None
for scheme, value in security.items():
security_requirement = None
for r in self.security:
if r.name == scheme:
security_requirement = r
self._request_handle_secschemes(r, value)

if security_requirement is None:
err_msg = """No security requirement satisfied (accepts {}) \
""".format(", ".join(self.security.keys()))
raise ValueError(err_msg)

if self.requestBody:
if self.requestBody.required and data is None:
err_msg = "Request Body is required but none was provided."
raise ValueError(err_msg)

self._request_handle_body(data)

self._request_handle_parameters(parameters)

if session is None:
session = self._session

# send the prepared request
result = session.send(self._request.prepare(), verify=verify)

# spec enforces these are strings
status_code = str(result.status_code)

# find the response model in spec we received
expected_response = None
if status_code in self.responses:
expected_response = self.responses[status_code]
elif "default" in self.responses:
expected_response = self.responses["default"]

if expected_response is None:
raise UnexpectedResponseError(result, self)

# if we got back a valid response code (or there was a default) and no
# response content was expected, return None
if expected_response.content is None:
return None

content_type = result.headers["Content-Type"]
if ";" in content_type:
# if the content type that came in included an encoding, we'll ignore
# it for now (requests has already parsed it for us) and only look at
# the MIME type when determining if an expected content type was returned.
content_type = content_type.split(";")[0].strip()

expected_media = expected_response.content.get(content_type, None)

# If raw_response is True, return the raw text or json based on content type
if raw_response:
if "application/json" in content_type:
return result.json()
return result.text

if expected_media is None and "/" in content_type:
# accept media type ranges in the spec. the most specific matching
# type should always be chosen, but if we do not have a match here
# a generic range should be accepted if one if provided
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#response-object

generic_type = content_type.split("/")[0] + "/*"
expected_media = expected_response.content.get(generic_type, None)

if expected_media is None:
err_msg = """Unexpected Content-Type {} returned for operation {} \
(expected one of {})"""
err_var = result.headers["Content-Type"], self.operationId, ",".join(expected_response.content.keys())

raise RuntimeError(err_msg.format(*err_var))

if content_type.lower() == "application/json":
return expected_media.schema.model(result.json())

raise NotImplementedError("Only application/json content type is supported")

# Apply the patch
Operation.request = patch_request


@component
Expand Down Expand Up @@ -89,12 +217,10 @@ def run(
"""
Processes a list of chat messages to invoke a method on an OpenAPI service.

It parses the last message in the list, expecting it to contain an OpenAI function calling descriptor
(name & parameters) in JSON format.
It parses the last message in the list, expecting it to contain tool calls.

:param messages: A list of `ChatMessage` objects containing the messages to be processed. The last message
should contain the function invocation payload in OpenAI function calling format. See the example in the class
docstring for the expected format.
should contain the tool calls.
:param service_openapi_spec: The OpenAPI JSON specification object of the service to be invoked. All the refs
should already be resolved.
:param service_credentials: The credentials to be used for authentication with the service.
Expand All @@ -105,29 +231,34 @@ def run(
response is in JSON format, and the `content` attribute of the `ChatMessage` contains
the JSON string.

:raises ValueError: If the last message is not from the assistant or if it does not contain the correct payload
to invoke a method on the service.
:raises ValueError: If the last message is not from the assistant or if it does not contain tool calls.
"""

last_message = messages[-1]
if not last_message.is_from(ChatRole.ASSISTANT):
raise ValueError(f"{last_message} is not from the assistant.")

function_invocation_payloads = self._parse_message(last_message)
tool_calls = last_message.tool_calls
if not tool_calls:
raise ValueError(f"The provided ChatMessage has no tool calls.\nChatMessage: {last_message}")

function_payloads = []
for tool_call in tool_calls:
function_payloads.append({"arguments": tool_call.arguments, "name": tool_call.tool_name})

# instantiate the OpenAPI service for the given specification
openapi_service = OpenAPI(service_openapi_spec, ssl_verify=self.ssl_verify)
self._authenticate_service(openapi_service, service_credentials)

response_messages = []
for method_invocation_descriptor in function_invocation_payloads:
for method_invocation_descriptor in function_payloads:
service_response = self._invoke_method(openapi_service, method_invocation_descriptor)
# openapi3 parses the JSON service response into a model object, which is not our focus at the moment.
# Instead, we require direct access to the raw JSON data of the response, rather than the model objects
# provided by the openapi3 library. This approach helps us avoid issues related to (de)serialization.
# By accessing the raw JSON response through `service_response._raw_data`, we can serialize this data
# into a string. Finally, we use this string to create a ChatMessage object.
response_messages.append(ChatMessage.from_user(json.dumps(service_response._raw_data)))
response_messages.append(ChatMessage.from_user(json.dumps(service_response)))

return {"service_response": response_messages}

Expand All @@ -152,35 +283,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenAPIServiceConnector":
"""
return default_from_dict(cls, data)

def _parse_message(self, message: ChatMessage) -> List[Dict[str, Any]]:
"""
Parses the message to extract the method invocation descriptor.

:param message: ChatMessage containing the tools calls
:return: A list of function invocation payloads
:raises ValueError: If the content is not valid JSON or lacks required fields.
"""
function_payloads = []
if message.text is None:
raise ValueError(f"The provided ChatMessage has no text.\nChatMessage: {message}")
try:
tool_calls = json.loads(message.text)
except json.JSONDecodeError:
raise ValueError("Invalid JSON content, expected OpenAI tools message.", message.text)

for tool_call in tool_calls:
# this should never happen, but just in case do a sanity check
if "type" not in tool_call:
raise ValueError("Message payload doesn't seem to be a tool invocation descriptor", message.text)

# In OpenAPIServiceConnector we know how to handle functions tools only
if tool_call["type"] == "function":
function_call = tool_call["function"]
function_payloads.append(
{"arguments": json.loads(function_call["arguments"]), "name": function_call["name"]}
)
return function_payloads

def _authenticate_service(self, openapi_service: "OpenAPI", credentials: Optional[Union[dict, str]] = None):
"""
Authentication with an OpenAPI service.
Expand Down Expand Up @@ -294,4 +396,4 @@ def _invoke_method(self, openapi_service: "OpenAPI", method_invocation_descripto
f"Missing requestBody parameter: '{param_name}' required for the '{name}' operation."
)
# call the underlying service REST API with the parameters
return method_to_call(**method_call_params)
return method_to_call(**method_call_params, raw_response=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Enhanced `OpenAPIServiceConnector` to support and be compatible with the new ChatMessage format.
Loading