From 53e466f56520e813c6e5014ca57e6efa9270f317 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 00:16:56 -0700 Subject: [PATCH 01/11] Added a tool server that exposes all tools as endpoints --- README.md | 11 ++++ pyproject.toml | 8 +++ src/aviary/main.py | 34 +++++++++++ src/aviary/tools/base.py | 2 +- src/aviary/tools/server.py | 112 +++++++++++++++++++++++++++++++++++++ tests/test_tools.py | 47 ++++++++++++++++ uv.lock | 56 ++++++++++++++++++- 7 files changed, 266 insertions(+), 4 deletions(-) create mode 100644 src/aviary/main.py create mode 100644 src/aviary/tools/server.py diff --git a/README.md b/README.md index e112ce95..2c8ce5cc 100644 --- a/README.md +++ b/README.md @@ -217,6 +217,17 @@ def export_frame(self) -> Frame: ) ``` +### View Environment Tools + +If an environment can be instantiated without anything other than a task (i.e., it implements `from_task`), you can start a server to view its tools: + +```sh +pip install fhaviary[server] +aviary tools [env name] +``` + +This will start a server that allows you to view the tools and call them, viewing the descriptions/types and output that an agent would see when using the server. + ## Environments ### GSM8k Environment diff --git a/pyproject.toml b/pyproject.toml index d1f72067..5c3ea88c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,11 @@ image = [ llm = [ "litellm", ] +server = [ + "click", + "fastapi", + "uvicorn", +] typing = [ "boto3-stubs[s3]", "numpy", @@ -47,6 +52,9 @@ xml = [ "dicttoxml", ] +[project.scripts] +aviary = "aviary:main" + [project.urls] issues = "https://github.com/Future-House/aviary/issues" repository = "https://github.com/Future-House/aviary" diff --git a/src/aviary/main.py b/src/aviary/main.py new file mode 100644 index 00000000..82fe672a --- /dev/null +++ b/src/aviary/main.py @@ -0,0 +1,34 @@ +import os + +import click +import uvicorn + +from aviary.env import Environment +from aviary.tools.server import make_tool_server + + +@click.group() +def cli(): + pass + + +@cli.command() +@click.argument("env") +@click.option("--host", default="localhost") +@click.option("--port", default=8000) +@click.option("--token", default="secret") +def tools(env, host, port, token): + if not os.environ.get("AUTH_TOKEN"): + os.environ["AUTH_TOKEN"] = token + # use empty task to trigger + # an empty task/no problem + env = Environment.from_name(env, task="") + app = make_tool_server(env.tools) + click.echo( + f"View tools at http://{host}:{port}/docs and log in with token {token!r}" + ) + uvicorn.run(app, host=host, port=port) + + +if __name__ == "__main__": + cli() diff --git a/src/aviary/tools/base.py b/src/aviary/tools/base.py index d7355b02..a5629114 100644 --- a/src/aviary/tools/base.py +++ b/src/aviary/tools/base.py @@ -39,7 +39,7 @@ dict: "object", None: "null", } - +reverse_type_map = {v: k for k, v in type_map.items()} # A string to denote an invalid tool. It can be used to indicate # an attempt to use a non-existent tool, missing/invalid parameters, diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py new file mode 100644 index 00000000..782d67c5 --- /dev/null +++ b/src/aviary/tools/server.py @@ -0,0 +1,112 @@ +import os +import secrets +import sys +from inspect import signature + +from pydantic import Field, create_model + +from aviary.tools.base import Tool, reverse_type_map +from aviary.utils import is_coroutine_callable + + +def make_tool_server(tools: list[Tool], name: str | None = None): # noqa: C901 + try: + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + except ModuleNotFoundError as exc: + raise ImportError( + "Please install aviary with the 'server' extra like so:" + " `pip install aviary[server]`." + ) from exc + + auth_scheme = HTTPBearer() + + async def validate_token( + credentials: HTTPAuthorizationCredentials = Depends(auth_scheme), # noqa: B008 + ) -> str: + # NOTE: don't use os.environ.get() to avoid possible empty string matches, and + # to have clearer server failures if the AUTH_TOKEN env var isn't present + if not secrets.compare_digest( + credentials.credentials, os.environ["AUTH_TOKEN"] + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect bearer token", + headers={"WWW-Authenticate": "Bearer"}, + ) + return credentials.credentials + + def create_request_model_from_tool(tool): + fields = {} + for pname, info in tool.info.parameters.properties.items(): + if pname == "type": + continue + # we just assume it exists + ptype = reverse_type_map[info["type"]] + + # decipher optional description, optional default, and type + if pname in tool.info.parameters.required: + if "description" in info: + fields[pname] = (ptype, Field(description=info["description"])) + else: + fields[pname] = (ptype, ...) + elif "description" in info: + fields[pname] = ( + ptype | None, + Field(description=info["description"], default=None), + ) + else: + fields[pname] = (ptype | None, None) + + return create_model(f"{tool.info.name.capitalize()}Params", **fields) # type: ignore[call-overload] + + web_app = FastAPI( + title=name or "Aviary Tool Server", + description="API Server for Aviary Environment Tools", + dependencies=[Depends(validate_token)], + ) + + # filter only for tools that are executable + tools = [tool for tool in tools if hasattr(tool, "_tool_fn")] + + # Dynamically create routes for each tool + for tool in tools: + tool_name = tool.info.name + tool_description = tool.info.description + RequestModel = create_request_model_from_tool(tool) + return_type = signature(tool._tool_fn).return_annotation + + # ensure the this will be in fast api scope + # close your eyes PR reviewers + # also fuck your IDE tools + RequestModel.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") + + def create_tool_handler(tool_fn, RequestModel, tool_description): + async def _tool_handler( + data: RequestModel, # type: ignore[valid-type] + ): + try: + # Call the tool function with the provided arguments + if is_coroutine_callable(tool_fn): + return await tool_fn(**data.model_dump()) # type: ignore[attr-defined] + return tool_fn(**data.model_dump()) # type: ignore[attr-defined] + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) from e + + _tool_handler.__doc__ = tool_description + return _tool_handler + + tool_handler = create_tool_handler( + tool._tool_fn, RequestModel, tool_description + ) + + # Add a POST route for the tool + web_app.post( + f"/{tool_name}", + summary=tool_name, + name=tool_name, + response_model=return_type, + description=tool_description, + )(tool_handler) + + return web_app diff --git a/tests/test_tools.py b/tests/test_tools.py index f1fad3f0..7778ae8c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1,10 +1,12 @@ import json +import os import pickle from collections.abc import Callable, Sequence from enum import IntEnum, auto from typing import Any import pytest +from fastapi.testclient import TestClient from pydantic import BaseModel, Field from pytest_subtests import SubTests @@ -17,6 +19,7 @@ ToolRequestMessage, argref_by_name, ) +from aviary.tools.server import make_tool_server def simple() -> None: @@ -749,3 +752,47 @@ def complex_typed_fn(c: Sequence[int], d: int | str) -> None: with pytest.raises(TypeError): # passing list[str], not list[int] type_checked_fn(c="str_list_arg", d="int_arg", state=s) + + +def test_make_tool_server(): + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + def subtract(a: int, b: int) -> int: + """Subtract two numbers. + + Args: + a: first number + b: second number + """ + return a - b + + tools = [ + Tool.from_function(add, allow_empty_param_descriptions=True), + Tool.from_function(subtract), + ] + server = make_tool_server(tools) + + # make sure there are two endpoints + route_names = [route.name for route in server.routes] + assert "add" in route_names + assert "subtract" in route_names + + # make sure we can call them + client = TestClient(server) + prev_token = os.environ.get("AUTH_TOKEN") + token = "test_make_tool_server" + os.environ["AUTH_TOKEN"] = token + + try: + response = client.post( + "/add", json={"a": 1, "b": 2}, headers={"Authorization": f"Bearer {token}"} + ) + assert response.status_code == 200 + assert response.json() == 3 + finally: + if prev_token is None: + del os.environ["AUTH_TOKEN"] + else: + os.environ["AUTH_TOKEN"] = prev_token diff --git a/uv.lock b/uv.lock index 0ee2d605..8537a82a 100644 --- a/uv.lock +++ b/uv.lock @@ -173,7 +173,7 @@ wheels = [ [[package]] name = "aviary-gsm8k" -version = "0.8.1.dev2+gfb8b2a7" +version = "0.5.1.dev49+gf34802d.d20241022" source = { editable = "packages/gsm8k" } dependencies = [ { name = "datasets" }, @@ -196,7 +196,7 @@ requires-dist = [ [[package]] name = "aviary-hotpotqa" -version = "0.8.1.dev2+gfb8b2a7" +version = "0.5.1.dev49+gf34802d.d20241022" source = { editable = "packages/hotpotqa" } dependencies = [ { name = "beautifulsoup4" }, @@ -573,9 +573,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/fd/afcd0496feca3276f509df3dbd5dae726fcc756f1a08d9e25abe1733f962/executing-2.1.0-py2.py3-none-any.whl", hash = "sha256:8d63781349375b5ebccc3142f4b30350c0cd9c79f921cde38be2be4637e98eaf", size = 25805 }, ] +[[package]] +name = "fastapi" +version = "0.115.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/fa/19e3c7c9b31ac291987c82e959f36f88840bea183fa3dc3bb654669f19c1/fastapi-0.115.2.tar.gz", hash = "sha256:3995739e0b09fa12f984bce8fa9ae197b35d433750d3d312422d846e283697ee", size = 299968 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c9/14/bbe7776356ef01f830f8085ca3ac2aea59c73727b6ffaa757abeb7d2900b/fastapi-0.115.2-py3-none-any.whl", hash = "sha256:61704c71286579cc5a598763905928f24ee98bfcc07aabe84cfefb98812bbc86", size = 94650 }, +] + [[package]] name = "fhaviary" -version = "0.8.1.dev2+gfb8b2a7" +version = "0.5.1.dev49+gf34802d.d20241022" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, @@ -599,6 +613,11 @@ image = [ llm = [ { name = "litellm" }, ] +server = [ + { name = "click" }, + { name = "fastapi" }, + { name = "uvicorn" }, +] typing = [ { name = "boto3-stubs", extra = ["s3"] }, { name = "numpy" }, @@ -613,8 +632,10 @@ dev = [ { name = "aviary-gsm8k", extra = ["typing"] }, { name = "aviary-hotpotqa" }, { name = "boto3-stubs", extra = ["s3"] }, + { name = "click" }, { name = "codeflash" }, { name = "dicttoxml" }, + { name = "fastapi" }, { name = "ipython" }, { name = "litellm" }, { name = "mypy" }, @@ -635,6 +656,7 @@ dev = [ { name = "sqlalchemy", extra = ["aiosqlite"] }, { name = "typeguard" }, { name = "types-pillow" }, + { name = "uvicorn" }, ] [package.metadata] @@ -643,14 +665,17 @@ requires-dist = [ { name = "aviary-hotpotqa", marker = "extra == 'hotpotqa'", editable = "packages/hotpotqa" }, { name = "boto3", marker = "extra == 'cloud'" }, { name = "boto3-stubs", extras = ["s3"], marker = "extra == 'typing'" }, + { name = "click", marker = "extra == 'server'" }, { name = "dicttoxml", marker = "extra == 'xml'" }, { name = "docstring-parser", specifier = ">=0.16" }, + { name = "fastapi", marker = "extra == 'server'" }, { name = "httpx" }, { name = "litellm", marker = "extra == 'llm'" }, { name = "numpy", marker = "extra == 'typing'" }, { name = "pillow", marker = "extra == 'image'" }, { name = "pydantic", specifier = "~=2.0" }, { name = "types-pillow", marker = "extra == 'typing'" }, + { name = "uvicorn", marker = "extra == 'server'" }, ] [package.metadata.requires-dev] @@ -2385,6 +2410,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, ] +[[package]] +name = "starlette" +version = "0.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/cb/244daf0d7be4508099ad5bca3cdfe8b8b5538acd719c5f397f614e569fff/starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35", size = 2573611 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/0f/64baf7a06492e8c12f5c4b49db286787a7255195df496fc21f5fd9eecffa/starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4", size = 73303 }, +] + [[package]] name = "tenacity" version = "9.0.0" @@ -2600,6 +2637,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/cf/8435d5a7159e2a9c83a95896ed596f68cf798005fe107cc655b5c5c14704/urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e", size = 144225 }, ] +[[package]] +name = "uvicorn" +version = "0.32.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/fc/1d785078eefd6945f3e5bab5c076e4230698046231eb0f3747bc5c8fa992/uvicorn-0.32.0.tar.gz", hash = "sha256:f78b36b143c16f54ccdb8190d0a26b5f1901fe5a3c777e1ab29f26391af8551e", size = 77564 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/14/78bd0e95dd2444b6caacbca2b730671d4295ccb628ef58b81bee903629df/uvicorn-0.32.0-py3-none-any.whl", hash = "sha256:60b8f3a5ac027dcd31448f411ced12b5ef452c646f76f02f8cc3f25d8d26fd82", size = 63723 }, +] + [[package]] name = "vcrpy" version = "6.0.2" From 387e92d56f8c8af68b0448267fbb15ced46a557b Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 00:23:49 -0700 Subject: [PATCH 02/11] Fixed README typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2c8ce5cc..ad81e81a 100644 --- a/README.md +++ b/README.md @@ -226,7 +226,7 @@ pip install fhaviary[server] aviary tools [env name] ``` -This will start a server that allows you to view the tools and call them, viewing the descriptions/types and output that an agent would see when using the server. +This will start a server that allows you to view the tools and call them, viewing the descriptions/types and output that an agent would see when using the tools. ## Environments From d1b1361b1cb22c54700c8820cfec104699d615ad Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 01:09:57 -0700 Subject: [PATCH 03/11] More robust way of initializing environments --- src/aviary/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/aviary/main.py b/src/aviary/main.py index 82fe672a..b1d4aeb8 100644 --- a/src/aviary/main.py +++ b/src/aviary/main.py @@ -1,3 +1,4 @@ +import asyncio import os import click @@ -23,7 +24,8 @@ def tools(env, host, port, token): # use empty task to trigger # an empty task/no problem env = Environment.from_name(env, task="") - app = make_tool_server(env.tools) + _, tools = asyncio.run(env.reset()) + app = make_tool_server(tools) click.echo( f"View tools at http://{host}:{port}/docs and log in with token {token!r}" ) From 9833f965f3ef06b482d58ca4382fe236cbbed5cb Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 17:08:05 -0700 Subject: [PATCH 04/11] Added environments to it now too --- src/aviary/__init__.py | 0 src/aviary/main.py | 8 ++- src/aviary/message.py | 4 ++ src/aviary/tools/base.py | 4 ++ src/aviary/tools/server.py | 117 +++++++++++++++++++++++++++++++++++-- tests/test_tools.py | 26 ++++++--- 6 files changed, 143 insertions(+), 16 deletions(-) create mode 100644 src/aviary/__init__.py diff --git a/src/aviary/__init__.py b/src/aviary/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aviary/main.py b/src/aviary/main.py index b1d4aeb8..e302463a 100644 --- a/src/aviary/main.py +++ b/src/aviary/main.py @@ -21,11 +21,13 @@ def cli(): def tools(env, host, port, token): if not os.environ.get("AUTH_TOKEN"): os.environ["AUTH_TOKEN"] = token + # use empty task to trigger # an empty task/no problem - env = Environment.from_name(env, task="") - _, tools = asyncio.run(env.reset()) - app = make_tool_server(tools) + def env_factory(): + return Environment.from_name(env, task="") + + app = asyncio.run(make_tool_server(env_factory)) click.echo( f"View tools at http://{host}:{port}/docs and log in with token {token!r}" ) diff --git a/src/aviary/message.py b/src/aviary/message.py index 0831f23d..97241b27 100644 --- a/src/aviary/message.py +++ b/src/aviary/message.py @@ -158,5 +158,9 @@ def common_retryable_errors_log_filter(cls, record: LogRecord) -> bool: return not all(x in record.msg for x in (cls.__name__, EMPTY_CONTENT_BASE_MSG)) +class EnvStateMessage(Message): + """A message that contains the current state of the environment.""" + + # Define separately so we can filter out this message type EMPTY_CONTENT_BASE_MSG = "No content in message" diff --git a/src/aviary/tools/base.py b/src/aviary/tools/base.py index a5629114..8509d621 100644 --- a/src/aviary/tools/base.py +++ b/src/aviary/tools/base.py @@ -297,10 +297,14 @@ def __init__( super().__init__(**kwargs) # NOTE: this Callable is excluded from serialization self._tool_fn = tool_fn + self._force_pickle_fn = False def __getstate__(self) -> dict[Any, Any]: # Prevent _tool_fn from being pickled, SEE: https://stackoverflow.com/a/2345953 state = super().__getstate__() + # allow forcing pickle, e.g., for cloud pickle sending + if self._force_pickle_fn: + return state state["__dict__"] = state["__dict__"].copy() state["__dict__"].pop("_tool_fn", None) return state diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 782d67c5..3fae3ffb 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -1,7 +1,11 @@ import os import secrets import sys +from collections.abc import Callable from inspect import signature +from pathlib import Path +from typing import Any +from uuid import uuid4 from pydantic import Field, create_model @@ -9,8 +13,27 @@ from aviary.utils import is_coroutine_callable -def make_tool_server(tools: list[Tool], name: str | None = None): # noqa: C901 +async def make_tool_server( + environment_factory: Callable, + name: str | None = None, + env_path: Path = Path("/tmp"), +): + """Create a FastAPI server for the provided environment. + + This function exposes one endpoint per tool and endpoints to create/view/delete environments. + In contrast to other environment servers that expose an action endpoint, this one exposes all tools individually. + + This is only for debugging tools and not intended as a strategy for working with environments. + Most environments have side-effects from using tools that occur in the step function. This + bypasses that and allows you to call tools directly. + + Args: + environment_factory: A callable that returns an environment instance. + name: The name of the server. Defaults to Aviary Tool Server. + env_path: The path to the directory to store environments + """ try: + import cloudpickle as pickle from fastapi import Depends, FastAPI, HTTPException, status from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer except ModuleNotFoundError as exc: @@ -36,13 +59,34 @@ async def validate_token( ) return credentials.credentials + # these seem useful in other contexts, but from what I read + # it is discouraged to save/load so leaving it defined here + def save_environment(environment, tools, environment_id): + # make sure we force all tools to pickle + for tool in tools: + tool._force_pickle_fn = True + with open(env_path / f"{environment_id}.pkl", "wb") as f: + pickle.dump((environment, tools), f) + + def load_environment(environment_id): + if not (env_path / f"{environment_id}.pkl").exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Environment {environment_id} not found", + ) + with open(env_path / f"{environment_id}.pkl", "rb") as f: + return pickle.load(f) + + def make_environment_id(): + return f"env{str(uuid4())[:8].replace('-', '')}" + def create_request_model_from_tool(tool): fields = {} for pname, info in tool.info.parameters.properties.items(): if pname == "type": continue # we just assume it exists - ptype = reverse_type_map[info["type"]] + ptype = reverse_type_map[info["type"]] if "type" in info else Any # decipher optional description, optional default, and type if pname in tool.info.parameters.required: @@ -66,6 +110,10 @@ def create_request_model_from_tool(tool): dependencies=[Depends(validate_token)], ) + # make a starting environment to save tools + env = environment_factory() + _, tools = await env.reset() + # filter only for tools that are executable tools = [tool for tool in tools if hasattr(tool, "_tool_fn")] @@ -81,23 +129,41 @@ def create_request_model_from_tool(tool): # also fuck your IDE tools RequestModel.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") - def create_tool_handler(tool_fn, RequestModel, tool_description): + def create_tool_handler(tool_name, RequestModel, tool_description): async def _tool_handler( data: RequestModel, # type: ignore[valid-type] + environment_id: str = "", ): + if environment_id: + env, env_tools = load_environment(environment_id) + else: + env = environment_factory() + _, env_tools = await env.reset() + environment_id = make_environment_id() + + # ok now find the tool_fn to call it with + # that came from the env I just loaded + tool_fn = next( + tool._tool_fn for tool in env_tools if tool.info.name == tool_name + ) + try: # Call the tool function with the provided arguments if is_coroutine_callable(tool_fn): - return await tool_fn(**data.model_dump()) # type: ignore[attr-defined] - return tool_fn(**data.model_dump()) # type: ignore[attr-defined] + result = await tool_fn(**data.model_dump()) # type: ignore[attr-defined] + else: + result = tool_fn(**data.model_dump()) # type: ignore[attr-defined] except Exception as e: raise HTTPException(status_code=500, detail=str(e)) from e + save_environment(env, env_tools, environment_id) + return {"result": result, "environment_id": environment_id} + _tool_handler.__doc__ = tool_description return _tool_handler tool_handler = create_tool_handler( - tool._tool_fn, RequestModel, tool_description + tool.info.name, RequestModel, tool_description ) # Add a POST route for the tool @@ -109,4 +175,43 @@ async def _tool_handler( description=tool_description, )(tool_handler) + # Add environment endpoints + @web_app.get( + "/env/create", + summary="Create Environment", + description="Create a new environment", + ) + async def create_environment_endpoint(): + env = environment_factory() + _, tools = await env.reset() + environment_id = make_environment_id() + save_environment(env, tools, environment_id) + return environment_id + + @web_app.get( + "/env/delete/{environment_id}", + summary="Delete Environment", + description="Delete an environment", + ) + async def delete_environment_endpoint(environment_id: str): + if (env_path / f"{environment_id}.pkl").exists(): + (env_path / f"{environment_id}.pkl").unlink() + return environment_id + + @web_app.get( + "/env/view/{environment_id}", + summary="View Environment", + description="View an environment", + ) + async def view_environment_endpoint(environment_id: str): + if not (env_path / f"{environment_id}.pkl").exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Environment {environment_id} not found", + ) + with open(env_path / f"{environment_id}.pkl", "rb") as f: + env, tools = pickle.load(f) + + return env.state + return web_app diff --git a/tests/test_tools.py b/tests/test_tools.py index 7778ae8c..749f40c8 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -10,7 +10,8 @@ from pydantic import BaseModel, Field from pytest_subtests import SubTests -from aviary.env import DummyEnv +from aviary.env import DummyEnv, Environment, Frame +from aviary.message import Message from aviary.tools import ( INVALID_TOOL_NAME, FunctionInfo, @@ -754,7 +755,8 @@ def complex_typed_fn(c: Sequence[int], d: int | str) -> None: type_checked_fn(c="str_list_arg", d="int_arg", state=s) -def test_make_tool_server(): +@pytest.mark.asyncio +async def test_make_tool_server(): def add(a: int, b: int) -> int: """Add two numbers.""" return a + b @@ -768,11 +770,21 @@ def subtract(a: int, b: int) -> int: """ return a - b - tools = [ - Tool.from_function(add, allow_empty_param_descriptions=True), - Tool.from_function(subtract), - ] - server = make_tool_server(tools) + class MyEnv(Environment): + async def reset(self) -> tuple[list[Message], list[Tool]]: + tools = [ + Tool.from_function(add, allow_empty_param_descriptions=True), + Tool.from_function(subtract), + ] + return [], tools + + async def step(self): + pass + + async def export_frame(self): + pass + + server = await make_tool_server(MyEnv) # make sure there are two endpoints route_names = [route.name for route in server.routes] From a1ab0a170d66ba86ef04b4be8c51b04eff11d453 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 17:09:57 -0700 Subject: [PATCH 05/11] Apply suggestions from code review Co-authored-by: James Braza --- src/aviary/tools/server.py | 11 ++++------- tests/test_tools.py | 12 ++---------- 2 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 3fae3ffb..becf1f20 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -80,7 +80,7 @@ def load_environment(environment_id): def make_environment_id(): return f"env{str(uuid4())[:8].replace('-', '')}" - def create_request_model_from_tool(tool): + def create_request_model_from_tool(tool: Tool) -> BaseModel: fields = {} for pname, info in tool.info.parameters.properties.items(): if pname == "type": @@ -114,11 +114,8 @@ def create_request_model_from_tool(tool): env = environment_factory() _, tools = await env.reset() - # filter only for tools that are executable - tools = [tool for tool in tools if hasattr(tool, "_tool_fn")] - # Dynamically create routes for each tool - for tool in tools: + for executable_tool in (t for t in tools if hasattr(t, "_tool_fn")): tool_name = tool.info.name tool_description = tool.info.description RequestModel = create_request_model_from_tool(tool) @@ -154,7 +151,7 @@ async def _tool_handler( else: result = tool_fn(**data.model_dump()) # type: ignore[attr-defined] except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e save_environment(env, env_tools, environment_id) return {"result": result, "environment_id": environment_id} @@ -166,7 +163,7 @@ async def _tool_handler( tool.info.name, RequestModel, tool_description ) - # Add a POST route for the tool + # Add a POST route so we can invoke the tool function web_app.post( f"/{tool_name}", summary=tool_name, diff --git a/tests/test_tools.py b/tests/test_tools.py index 749f40c8..23ffc9d3 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -793,18 +793,10 @@ async def export_frame(self): # make sure we can call them client = TestClient(server) - prev_token = os.environ.get("AUTH_TOKEN") - token = "test_make_tool_server" - os.environ["AUTH_TOKEN"] = token - - try: + token = "stub" + with patch.dict(os.environ, {"AUTH_TOKEN": token}): response = client.post( "/add", json={"a": 1, "b": 2}, headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200 assert response.json() == 3 - finally: - if prev_token is None: - del os.environ["AUTH_TOKEN"] - else: - os.environ["AUTH_TOKEN"] = prev_token From dbf00895a8bb5283b7cca19daa944d963a128517 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Tue, 22 Oct 2024 17:13:44 -0700 Subject: [PATCH 06/11] PR comments --- src/aviary/main.py | 5 ++++- src/aviary/tools/server.py | 11 ++++++----- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/aviary/main.py b/src/aviary/main.py index e302463a..050fe06c 100644 --- a/src/aviary/main.py +++ b/src/aviary/main.py @@ -8,6 +8,9 @@ from aviary.tools.server import make_tool_server +# this enables sub commands for the CLI +# so we can call `aviary tools` to start the tool server +# rather than aviary-tools or something @click.group() def cli(): pass @@ -18,7 +21,7 @@ def cli(): @click.option("--host", default="localhost") @click.option("--port", default=8000) @click.option("--token", default="secret") -def tools(env, host, port, token): +def tools(env: str, host: str, port: int, token: str): if not os.environ.get("AUTH_TOKEN"): os.environ["AUTH_TOKEN"] = token diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 3fae3ffb..a9c7d8ab 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -9,14 +9,14 @@ from pydantic import Field, create_model -from aviary.tools.base import Tool, reverse_type_map +from aviary.tools.base import reverse_type_map from aviary.utils import is_coroutine_callable -async def make_tool_server( +async def make_tool_server( # noqa: C901, PLR0915 environment_factory: Callable, name: str | None = None, - env_path: Path = Path("/tmp"), + env_path: Path = Path("/tmp"), # noqa: S108 ): """Create a FastAPI server for the provided environment. @@ -125,6 +125,7 @@ def create_request_model_from_tool(tool): return_type = signature(tool._tool_fn).return_annotation # ensure the this will be in fast api scope + # because fastapi will barf on a request model that isn't in scope # close your eyes PR reviewers # also fuck your IDE tools RequestModel.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__") @@ -209,8 +210,8 @@ async def view_environment_endpoint(environment_id: str): status_code=status.HTTP_404_NOT_FOUND, detail=f"Environment {environment_id} not found", ) - with open(env_path / f"{environment_id}.pkl", "rb") as f: - env, tools = pickle.load(f) + with open(env_path / f"{environment_id}.pkl", "rb") as f: # noqa: ASYNC230 + env, _ = pickle.load(f) return env.state From 073182ab264d92ff2dc524b926fedb10c58e9e34 Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 23 Oct 2024 16:06:20 -0700 Subject: [PATCH 07/11] Included environment usage in tool server --- src/aviary/tools/server.py | 35 ++++++++++++++++++----------------- tests/test_tools.py | 14 +++++++++----- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 001b32b9..6a654800 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -2,15 +2,13 @@ import secrets import sys from collections.abc import Callable -from inspect import signature from pathlib import Path from typing import Any from uuid import uuid4 -from pydantic import Field, create_model +from pydantic import BaseModel, Field, create_model -from aviary.tools.base import reverse_type_map -from aviary.utils import is_coroutine_callable +from aviary.tools.base import Tool, ToolCall, ToolRequestMessage, reverse_type_map async def make_tool_server( # noqa: C901, PLR0915 @@ -115,11 +113,10 @@ def create_request_model_from_tool(tool: Tool) -> BaseModel: _, tools = await env.reset() # Dynamically create routes for each tool - for executable_tool in (t for t in tools if hasattr(t, "_tool_fn")): + for tool in (t for t in tools if hasattr(t, "_tool_fn")): tool_name = tool.info.name tool_description = tool.info.description RequestModel = create_request_model_from_tool(tool) - return_type = signature(tool._tool_fn).return_annotation # ensure the this will be in fast api scope # because fastapi will barf on a request model that isn't in scope @@ -141,21 +138,26 @@ async def _tool_handler( # ok now find the tool_fn to call it with # that came from the env I just loaded - tool_fn = next( - tool._tool_fn for tool in env_tools if tool.info.name == tool_name + msg = ToolRequestMessage( + tool_calls=[ToolCall.from_name(tool_name, **data.model_dump())] # type: ignore[attr-defined] ) - try: - # Call the tool function with the provided arguments - if is_coroutine_callable(tool_fn): - result = await tool_fn(**data.model_dump()) # type: ignore[attr-defined] - else: - result = tool_fn(**data.model_dump()) # type: ignore[attr-defined] + result_msgs, done, *_ = await env.step(msg) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) from e + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e) + ) from e + + if done: + _, env_tools = await env.reset() save_environment(env, env_tools, environment_id) - return {"result": result, "environment_id": environment_id} + return { + "result": "\n\n".join([ + str(msg.content) for msg in result_msgs if msg.content + ]), + "environment_id": environment_id, + } _tool_handler.__doc__ = tool_description return _tool_handler @@ -169,7 +171,6 @@ async def _tool_handler( f"/{tool_name}", summary=tool_name, name=tool_name, - response_model=return_type, description=tool_description, )(tool_handler) diff --git a/tests/test_tools.py b/tests/test_tools.py index 23ffc9d3..8e04681c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -4,13 +4,15 @@ from collections.abc import Callable, Sequence from enum import IntEnum, auto from typing import Any +from unittest.mock import patch import pytest from fastapi.testclient import TestClient from pydantic import BaseModel, Field from pytest_subtests import SubTests +from typeguard import suppress_type_checks -from aviary.env import DummyEnv, Environment, Frame +from aviary.env import DummyEnv, Environment from aviary.message import Message from aviary.tools import ( INVALID_TOOL_NAME, @@ -776,15 +778,17 @@ async def reset(self) -> tuple[list[Message], list[Tool]]: Tool.from_function(add, allow_empty_param_descriptions=True), Tool.from_function(subtract), ] + self.tools = tools return [], tools - async def step(self): - pass + async def step(self, action): + return await self.exec_tool_calls(action), False, 0, 0 async def export_frame(self): pass - server = await make_tool_server(MyEnv) + with suppress_type_checks(): + server = await make_tool_server(MyEnv) # make sure there are two endpoints route_names = [route.name for route in server.routes] @@ -799,4 +803,4 @@ async def export_frame(self): "/add", json={"a": 1, "b": 2}, headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200 - assert response.json() == 3 + assert response.json()["result"] == "3" From 0edb0c012d94f180624ee265592ce6b2c70993ed Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 23 Oct 2024 16:19:19 -0700 Subject: [PATCH 08/11] Cloudpickle --- pyproject.toml | 1 + uv.lock | 14 +++++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5c3ea88c..0c9cae9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ llm = [ ] server = [ "click", + "cloudpickle", "fastapi", "uvicorn", ] diff --git a/uv.lock b/uv.lock index 8537a82a..0d810d6c 100644 --- a/uv.lock +++ b/uv.lock @@ -423,6 +423,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/2e/d53fa4befbf2cfa713304affc7ca780ce4fc1fd8710527771b58311a3229/click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28", size = 97941 }, ] +[[package]] +name = "cloudpickle" +version = "3.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/97/c7/f746cadd08c4c08129215cf1b984b632f9e579fc781301e63da9e85c76c1/cloudpickle-3.1.0.tar.gz", hash = "sha256:81a929b6e3c7335c863c771d673d105f02efdb89dfaba0c90495d1c64796601b", size = 66155 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/41/e1d85ca3cab0b674e277c8c4f678cf66a91cd2cecf93df94353a606fe0db/cloudpickle-3.1.0-py3-none-any.whl", hash = "sha256:fe11acda67f61aaaec473e3afe030feb131d78a43461b718185363384f1ba12e", size = 22021 }, +] + [[package]] name = "codeflash" version = "0.7.1" @@ -589,7 +598,7 @@ wheels = [ [[package]] name = "fhaviary" -version = "0.5.1.dev49+gf34802d.d20241022" +version = "0.5.1.dev59+g8d9a4e4.d20241023" source = { editable = "." } dependencies = [ { name = "docstring-parser" }, @@ -615,6 +624,7 @@ llm = [ ] server = [ { name = "click" }, + { name = "cloudpickle" }, { name = "fastapi" }, { name = "uvicorn" }, ] @@ -633,6 +643,7 @@ dev = [ { name = "aviary-hotpotqa" }, { name = "boto3-stubs", extra = ["s3"] }, { name = "click" }, + { name = "cloudpickle" }, { name = "codeflash" }, { name = "dicttoxml" }, { name = "fastapi" }, @@ -666,6 +677,7 @@ requires-dist = [ { name = "boto3", marker = "extra == 'cloud'" }, { name = "boto3-stubs", extras = ["s3"], marker = "extra == 'typing'" }, { name = "click", marker = "extra == 'server'" }, + { name = "cloudpickle", marker = "extra == 'server'" }, { name = "dicttoxml", marker = "extra == 'xml'" }, { name = "docstring-parser", specifier = ">=0.16" }, { name = "fastapi", marker = "extra == 'server'" }, From 5323fb58b5b24926046f0ca464f2364c4bee686d Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 23 Oct 2024 16:38:01 -0700 Subject: [PATCH 09/11] Removed __init__ --- pyproject.toml | 2 +- src/aviary/__init__.py | 0 src/aviary/main.py | 10 ++++++++-- src/aviary/tools/server.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) delete mode 100644 src/aviary/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 0c9cae9b..d5abf221 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ xml = [ ] [project.scripts] -aviary = "aviary:main" +aviary = "aviary.main:cli" [project.urls] issues = "https://github.com/Future-House/aviary/issues" diff --git a/src/aviary/__init__.py b/src/aviary/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/aviary/main.py b/src/aviary/main.py index 050fe06c..03fc4c0d 100644 --- a/src/aviary/main.py +++ b/src/aviary/main.py @@ -1,8 +1,14 @@ import asyncio import os -import click -import uvicorn +try: + import click + import uvicorn +except ImportError as e: + raise ImportError( + "CLI requires the 'server' extra for 'aviary'. Please:" + " `pip install aviary[server]`." + ) from e from aviary.env import Environment from aviary.tools.server import make_tool_server diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 6a654800..b36e46a5 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -208,7 +208,7 @@ async def view_environment_endpoint(environment_id: str): status_code=status.HTTP_404_NOT_FOUND, detail=f"Environment {environment_id} not found", ) - with open(env_path / f"{environment_id}.pkl", "rb") as f: # noqa: ASYNC230 + with (env_path / f"{environment_id}.pkl").open("rb") as f: env, _ = pickle.load(f) return env.state From bf1f294df802d041172bc04ee9d59c7583a095ce Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 23 Oct 2024 16:38:41 -0700 Subject: [PATCH 10/11] PR comments --- src/aviary/tools/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index b36e46a5..02856faa 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -13,7 +13,7 @@ async def make_tool_server( # noqa: C901, PLR0915 environment_factory: Callable, - name: str | None = None, + name: str = "Aviary Tool Server", env_path: Path = Path("/tmp"), # noqa: S108 ): """Create a FastAPI server for the provided environment. @@ -103,7 +103,7 @@ def create_request_model_from_tool(tool: Tool) -> BaseModel: return create_model(f"{tool.info.name.capitalize()}Params", **fields) # type: ignore[call-overload] web_app = FastAPI( - title=name or "Aviary Tool Server", + title=name, description="API Server for Aviary Environment Tools", dependencies=[Depends(validate_token)], ) From 1c459c9b220f5e7eb533ed70410c1fb498ef597b Mon Sep 17 00:00:00 2001 From: Andrew White Date: Wed, 23 Oct 2024 16:50:07 -0700 Subject: [PATCH 11/11] windows and B008 --- src/aviary/tools/server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/aviary/tools/server.py b/src/aviary/tools/server.py index 02856faa..3725fc65 100644 --- a/src/aviary/tools/server.py +++ b/src/aviary/tools/server.py @@ -1,6 +1,7 @@ import os import secrets import sys +import tempfile from collections.abc import Callable from pathlib import Path from typing import Any @@ -14,7 +15,7 @@ async def make_tool_server( # noqa: C901, PLR0915 environment_factory: Callable, name: str = "Aviary Tool Server", - env_path: Path = Path("/tmp"), # noqa: S108 + env_path: Path | None = None, ): """Create a FastAPI server for the provided environment. @@ -40,6 +41,8 @@ async def make_tool_server( # noqa: C901, PLR0915 " `pip install aviary[server]`." ) from exc + if not env_path: + env_path = Path(tempfile.gettempdir()) auth_scheme = HTTPBearer() async def validate_token(