Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert api calls and integrations to tools #833

Draft
wants to merge 9 commits into
base: dev
Choose a base branch
from
21 changes: 10 additions & 11 deletions agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Callable
from typing import Any, Callable

from anthropic import AsyncAnthropic # Import AsyncAnthropic client
from anthropic.types.beta.beta_message import BetaMessage
Expand All @@ -8,6 +8,7 @@
from langchain_core.tools.convert import tool as tool_decorator
from litellm import ChatCompletionMessageToolCall, Function, Message
from litellm.types.utils import Choices, ModelResponse
from pydantic import BaseModel
from temporalio import activity
from temporalio.exceptions import ApplicationError

Expand All @@ -19,7 +20,7 @@
from ...common.storage_handler import auto_blob_store
from ...common.utils.template import render_template
from ...env import anthropic_api_key, debug
from ..utils import get_handler_with_filtered_params
from ..utils import get_handler_with_filtered_params, get_integration_arguments
from .base_evaluate import base_evaluate

COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22"
Expand Down Expand Up @@ -70,13 +71,11 @@ def format_tool(tool: Tool) -> dict:

formatted["function"]["parameters"] = json_schema

# # FIXME: Implement integration tools
# elif tool.type == "integration":
# raise NotImplementedError("Integration tools are not supported")
elif tool.type == "integration" and tool.integration:
formatted["function"]["parameters"] = get_integration_arguments(tool)

# # FIXME: Implement API call tools
# elif tool.type == "api_call":
# raise NotImplementedError("API call tools are not supported")
elif tool.type == "api_call" and tool.api_call:
formatted["function"]["parameters"] = tool.api_call.schema_

return formatted

Expand Down Expand Up @@ -146,7 +145,9 @@ async def prompt_step(context: StepContext) -> StepOutcome:

# Get passed settings
passed_settings: dict = context.current_step.model_dump(
exclude=excluded_keys, exclude_unset=True
# TODO: Should we exclude unset?
exclude=excluded_keys,
exclude_unset=True,
)
passed_settings.update(passed_settings.pop("settings", {}))

Expand Down Expand Up @@ -251,8 +252,6 @@ async def prompt_step(context: StepContext) -> StepOutcome:
)

else:
# FIXME: hardcoded tool to a None value as the tool calls are not implemented yet
formatted_tools = None
# Use litellm for other models
completion_data: dict = {
"model": agent_model,
Expand Down
199 changes: 198 additions & 1 deletion agents-api/agents_api/activities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,43 @@
import statistics
import string
import time
import types
import urllib.parse
from typing import Any, Callable, ParamSpec, TypeVar
from typing import (
Annotated,
Any,
Callable,
Literal,
ParamSpec,
TypeVar,
get_args,
get_origin,
)

import re2
import zoneinfo
from beartype import beartype
from pydantic import BaseModel
from simpleeval import EvalWithCompoundTypes, SimpleEval

from ..autogen.openapi_model import SystemDef
from ..autogen.Tools import (
BraveSearchArguments,
BrowserbaseCompleteSessionArguments,
BrowserbaseContextArguments,
BrowserbaseCreateSessionArguments,
BrowserbaseExtensionArguments,
BrowserbaseGetSessionArguments,
BrowserbaseGetSessionConnectUrlArguments,
BrowserbaseGetSessionLiveUrlsArguments,
BrowserbaseListSessionsArguments,
EmailArguments,
RemoteBrowserArguments,
SpiderFetchArguments,
Tool,
WeatherGetArguments,
WikipediaSearchArguments,
)
from ..common.utils import yaml

T = TypeVar("T")
Expand Down Expand Up @@ -56,6 +84,101 @@
"match_regex": lambda pattern, string: bool(re2.fullmatch(pattern, string)),
}

_args_desc_map = {
BraveSearchArguments: {
"query": "The search query for searching with Brave",
},
EmailArguments: {
"to": "The email address to send the email to",
"from_": "The email address to send the email from",
"subject": "The subject of the email",
"body": "The body of the email",
},
SpiderFetchArguments: {
"url": "The URL to fetch data from",
"mode": "The type of crawler to use",
"params": "Additional parameters for the Spider API",
},
WikipediaSearchArguments: {
"query": "The search query string",
"load_max_docs": "Maximum number of documents to load",
},
WeatherGetArguments: {
"location": "The location for which to fetch weather data",
},
BrowserbaseContextArguments: {
"project_id": "The Project ID. Can be found in Settings.",
},
BrowserbaseExtensionArguments: {
"repository_name": "The GitHub repository name.",
"ref": "Ref to install from a branch or tag.",
},
BrowserbaseListSessionsArguments: {
"status": "The status of the sessions to list (Available options: RUNNING, ERROR, TIMED_OUT, COMPLETED)",
},
BrowserbaseCreateSessionArguments: {
"project_id": "The Project ID. Can be found in Settings.",
"extension_id": "The installed Extension ID. See Install Extension from GitHub.",
"browser_settings": "Browser settings",
"timeout": "Duration in seconds after which the session will automatically end. Defaults to the Project's defaultTimeout.",
"keep_alive": "Set to true to keep the session alive even after disconnections. This is available on the Startup plan only.",
"proxies": "Proxy configuration. Can be true for default proxy, or an array of proxy configurations.",
},
BrowserbaseGetSessionArguments: {
"id": "Session ID",
},
BrowserbaseCompleteSessionArguments: {
"id": "Session ID",
"status": "Session status",
},
BrowserbaseGetSessionLiveUrlsArguments: {
"id": "Session ID",
},
BrowserbaseGetSessionConnectUrlArguments: {
"id": "Session ID",
},
RemoteBrowserArguments: {
"connect_url": "The connection URL for the remote browser",
"action": "The action to perform",
"text": "The text",
"coordinate": "The coordinate to move the mouse to",
},
}

_providers_map = {
"brave": BraveSearchArguments,
"email": EmailArguments,
"spider": SpiderFetchArguments,
"wikipedia": WikipediaSearchArguments,
"weather": WeatherGetArguments,
"browserbase": {
"create_context": BrowserbaseContextArguments,
"install_extension_from_github": BrowserbaseExtensionArguments,
"list_sessions": BrowserbaseListSessionsArguments,
"create_session": BrowserbaseCreateSessionArguments,
"get_session": BrowserbaseGetSessionArguments,
"complete_session": BrowserbaseCompleteSessionArguments,
"get_live_urls": BrowserbaseGetSessionLiveUrlsArguments,
"get_connect_url": BrowserbaseGetSessionConnectUrlArguments,
},
"remote_browser": RemoteBrowserArguments,
}


_arg_types_map = {
BrowserbaseCreateSessionArguments: {
"proxies": {
"type": "boolean | array",
},
},
BrowserbaseListSessionsArguments: {
"status": {
"type": "string",
"enum": "RUNNING,ERROR,TIMED_OUT,COMPLETED",
},
},
}


class stdlib_re:
fullmatch = re2.fullmatch
Expand Down Expand Up @@ -378,3 +501,77 @@ def get_handler(system: SystemDef) -> Callable:
raise NotImplementedError(
f"System call not implemented for {system.resource}.{system.operation}"
)


def _annotation_to_type(
annotation: type, args_model: type[BaseModel], fld_name: str
) -> dict[str, str]:
type_, enum = None, None
if get_origin(annotation) is Literal:
type_ = "string"
enum = ",".join(annotation.__args__)
elif annotation is str:
type_ = "string"
elif annotation in (int, float):
type_ = "number"
elif annotation is list:
type_ = "array"
elif annotation is bool:
type_ = "boolean"
whiterabbit1983 marked this conversation as resolved.
Show resolved Hide resolved
elif annotation == type(None):
type_ = "null"
elif get_origin(annotation) is types.UnionType:
args = [arg for arg in get_args(annotation) if arg is not types.NoneType]
creatorrr marked this conversation as resolved.
Show resolved Hide resolved
if len(args):
return _annotation_to_type(args[0], args_model, fld_name)
else:
type_ = "null"
elif annotation is dict:
type_ = "object"
else:
type_ = _arg_types_map.get(args_model, {fld_name: {"type": "object"}}).get(
fld_name, {"type": "object"}
)["type"]
enum = _arg_types_map.get(args_model, {}).get(fld_name, {}).get("enum")

result = {
"type": type_,
}
if enum is not None:
result.update({"enum": enum})

return result


def get_integration_arguments(tool: Tool):
properties = {
"type": "object",
"properties": {},
"required": [],
}

integration_args: type[BaseModel] | dict[str, type[BaseModel]] | None = (
_providers_map.get(tool.integration.provider)
)

if integration_args is None:
return properties

if isinstance(integration_args, dict):
integration_args: type[BaseModel] | None = integration_args.get(
tool.integration.method
)

if integration_args is None:
return properties

for fld_name, fld_annotation in integration_args.model_fields.items():
tp = _annotation_to_type(fld_annotation.annotation, integration_args, fld_name)
tp["description"] = _args_desc_map.get(integration_args, fld_name).get(
fld_name, fld_name
)
properties["properties"][fld_name] = tp
if fld_annotation.is_required():
properties["required"].append(fld_name)

return properties
6 changes: 3 additions & 3 deletions agents-api/agents_api/workflows/task_execution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ async def run(
]["type"] == "integration":
workflow.logger.debug("Prompt step: Received INTEGRATION tool call")

# FIXME: Implement integration tool calls
# TODO: Implement integration tool calls
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@whiterabbit1983 we need to implement these 3 blocks as well, refer to the ToolStep(...) for how to do them

# See: MANUAL TOOL CALL INTEGRATION (below)
raise NotImplementedError("Integration tool calls not yet supported")

Expand All @@ -452,7 +452,7 @@ async def run(
]["type"] == "api_call":
workflow.logger.debug("Prompt step: Received API_CALL tool call")

# FIXME: Implement API_CALL tool calls
# TODO: Implement API_CALL tool calls
# See: MANUAL TOOL CALL API_CALL (below)
raise NotImplementedError("API_CALL tool calls not yet supported")

Expand All @@ -467,7 +467,7 @@ async def run(
]["type"] == "system":
workflow.logger.debug("Prompt step: Received SYSTEM tool call")

# FIXME: Implement SYSTEM tool calls
# TODO: Implement SYSTEM tool calls
# See: MANUAL TOOL CALL SYSTEM (below)
raise NotImplementedError("SYSTEM tool calls not yet supported")

Expand Down
Loading