Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a tool server that exposes all tools as endpoints #79

Merged
merged 13 commits into from
Oct 24, 2024
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,17 @@ def export_frame(self) -> Frame:
)
```

### View Environment Tools

If an environment can be instantiated without anything other than a task (i.e., it implements `from_task`), you can start a server to view its tools:

```sh
pip install fhaviary[server]
aviary tools [env name]
```

This will start a server that allows you to view the tools and call them, viewing the descriptions/types and output that an agent would see when using the tools.

## Environments

### GSM8k Environment
Expand Down
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ image = [
llm = [
"litellm",
]
server = [
"click",
"cloudpickle",
"fastapi",
"uvicorn",
]
typing = [
"boto3-stubs[s3]",
"numpy",
Expand All @@ -47,6 +53,9 @@ xml = [
"dicttoxml",
]

[project.scripts]
aviary = "aviary.main:cli"

[project.urls]
issues = "https://github.com/Future-House/aviary/issues"
repository = "https://github.com/Future-House/aviary"
Expand Down
47 changes: 47 additions & 0 deletions src/aviary/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import asyncio
import os

try:
import click
import uvicorn
except ImportError as e:
raise ImportError(
"CLI requires the 'server' extra for 'aviary'. Please:"
" `pip install aviary[server]`."
) from e

from aviary.env import Environment
from aviary.tools.server import make_tool_server


# this enables sub commands for the CLI
# so we can call `aviary tools` to start the tool server
# rather than aviary-tools or something
@click.group()
def cli():
pass
jamesbraza marked this conversation as resolved.
Show resolved Hide resolved


@cli.command()
@click.argument("env")
@click.option("--host", default="localhost")
@click.option("--port", default=8000)
@click.option("--token", default="secret")
def tools(env: str, host: str, port: int, token: str):
if not os.environ.get("AUTH_TOKEN"):
os.environ["AUTH_TOKEN"] = token

# use empty task to trigger
# an empty task/no problem
def env_factory():
return Environment.from_name(env, task="")

app = asyncio.run(make_tool_server(env_factory))
click.echo(
f"View tools at http://{host}:{port}/docs and log in with token {token!r}"
)
uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
cli()
4 changes: 4 additions & 0 deletions src/aviary/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,9 @@ def common_retryable_errors_log_filter(cls, record: LogRecord) -> bool:
return not all(x in record.msg for x in (cls.__name__, EMPTY_CONTENT_BASE_MSG))


class EnvStateMessage(Message):
"""A message that contains the current state of the environment."""
Comment on lines +161 to +162
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this used at all?

And if it's used, do you mind documenting where the current state gets housed (e.g. is it JSON in the content)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is related to tree search - I had to promote this code to prevent dependence on aviary-internal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, I think when @sidnarayanan promotes the rest of his code it will be documented then.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fwiw, tree search is not going to be in aviary, it will be in LDP. Should we put it in LDP with tree search?

I guess I think it's confusing to have this message subclass here without usage or docs, and no specialized behaviors

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EnvStateMessage should be in aviary - it can be used by environments to indicate that a particular message in the returned observation represents the state of the env. I don't see any harm in keeping it here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll have to work this out in another PR - basically @sidnarayanan needs this message type in cloning, cloning should not depend on ldp or aviary-internal.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we at least put Sid's comment into the docstring for EnvStateMessage then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Beyond "A message that contains the current state of the environment." ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like this:

Suggested change
class EnvStateMessage(Message):
"""A message that contains the current state of the environment."""
class EnvStateMessage(Message):
"""
A message variant whose contents are known to contain the environment state.
For example, the contents can be a JSON serialization of the environment state.
"""



# Define separately so we can filter out this message type
EMPTY_CONTENT_BASE_MSG = "No content in message"
6 changes: 5 additions & 1 deletion src/aviary/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
dict: "object",
None: "null",
}

reverse_type_map = {v: k for k, v in type_map.items()}

# A string to denote an invalid tool. It can be used to indicate
# an attempt to use a non-existent tool, missing/invalid parameters,
Expand Down Expand Up @@ -297,10 +297,14 @@ def __init__(
super().__init__(**kwargs)
# NOTE: this Callable is excluded from serialization
self._tool_fn = tool_fn
self._force_pickle_fn = False

def __getstate__(self) -> dict[Any, Any]:
# Prevent _tool_fn from being pickled, SEE: https://stackoverflow.com/a/2345953
state = super().__getstate__()
# allow forcing pickle, e.g., for cloud pickle sending
if self._force_pickle_fn:
return state
state["__dict__"] = state["__dict__"].copy()
state["__dict__"].pop("_tool_fn", None)
return state
Expand Down
219 changes: 219 additions & 0 deletions src/aviary/tools/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
import os
import secrets
import sys
import tempfile
from collections.abc import Callable
from pathlib import Path
from typing import Any
from uuid import uuid4

from pydantic import BaseModel, Field, create_model

from aviary.tools.base import Tool, ToolCall, ToolRequestMessage, reverse_type_map


async def make_tool_server( # noqa: C901, PLR0915
environment_factory: Callable,
name: str = "Aviary Tool Server",
env_path: Path | None = None,
):
"""Create a FastAPI server for the provided environment.

This function exposes one endpoint per tool and endpoints to create/view/delete environments.
In contrast to other environment servers that expose an action endpoint, this one exposes all tools individually.

This is only for debugging tools and not intended as a strategy for working with environments.
Most environments have side-effects from using tools that occur in the step function. This
bypasses that and allows you to call tools directly.

Args:
environment_factory: A callable that returns an environment instance.
name: The name of the server. Defaults to Aviary Tool Server.
env_path: The path to the directory to store environments
"""
try:
import cloudpickle as pickle
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
except ModuleNotFoundError as exc:
raise ImportError(
"Please install aviary with the 'server' extra like so:"
" `pip install aviary[server]`."
) from exc

if not env_path:
env_path = Path(tempfile.gettempdir())
auth_scheme = HTTPBearer()

async def validate_token(
credentials: HTTPAuthorizationCredentials = Depends(auth_scheme), # noqa: B008
) -> str:
# NOTE: don't use os.environ.get() to avoid possible empty string matches, and
# to have clearer server failures if the AUTH_TOKEN env var isn't present
if not secrets.compare_digest(
credentials.credentials, os.environ["AUTH_TOKEN"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect bearer token",
headers={"WWW-Authenticate": "Bearer"},
)
return credentials.credentials

# these seem useful in other contexts, but from what I read
# it is discouraged to save/load so leaving it defined here
def save_environment(environment, tools, environment_id):
# make sure we force all tools to pickle
for tool in tools:
tool._force_pickle_fn = True
with open(env_path / f"{environment_id}.pkl", "wb") as f:
pickle.dump((environment, tools), f)

def load_environment(environment_id):
if not (env_path / f"{environment_id}.pkl").exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Environment {environment_id} not found",
)
with open(env_path / f"{environment_id}.pkl", "rb") as f:
return pickle.load(f)

def make_environment_id():
return f"env{str(uuid4())[:8].replace('-', '')}"

def create_request_model_from_tool(tool: Tool) -> BaseModel:
fields = {}
for pname, info in tool.info.parameters.properties.items():
if pname == "type":
continue
# we just assume it exists
ptype = reverse_type_map[info["type"]] if "type" in info else Any

# decipher optional description, optional default, and type
if pname in tool.info.parameters.required:
if "description" in info:
fields[pname] = (ptype, Field(description=info["description"]))
else:
fields[pname] = (ptype, ...)
elif "description" in info:
fields[pname] = (
ptype | None,
Field(description=info["description"], default=None),
)
else:
fields[pname] = (ptype | None, None)

return create_model(f"{tool.info.name.capitalize()}Params", **fields) # type: ignore[call-overload]

web_app = FastAPI(
title=name,
description="API Server for Aviary Environment Tools",
dependencies=[Depends(validate_token)],
)

# make a starting environment to save tools
env = environment_factory()
_, tools = await env.reset()

# Dynamically create routes for each tool
for tool in (t for t in tools if hasattr(t, "_tool_fn")):
tool_name = tool.info.name
tool_description = tool.info.description
RequestModel = create_request_model_from_tool(tool)

# ensure the this will be in fast api scope
whitead marked this conversation as resolved.
Show resolved Hide resolved
# because fastapi will barf on a request model that isn't in scope
# close your eyes PR reviewers
# also fuck your IDE tools
RequestModel.__module__ = sys._getframe(1).f_globals.get("__name__", "__main__")

def create_tool_handler(tool_name, RequestModel, tool_description):
async def _tool_handler(
data: RequestModel, # type: ignore[valid-type]
environment_id: str = "",
):
if environment_id:
env, env_tools = load_environment(environment_id)
else:
env = environment_factory()
_, env_tools = await env.reset()
environment_id = make_environment_id()

# ok now find the tool_fn to call it with
# that came from the env I just loaded
msg = ToolRequestMessage(
tool_calls=[ToolCall.from_name(tool_name, **data.model_dump())] # type: ignore[attr-defined]
)
try:
result_msgs, done, *_ = await env.step(msg)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
) from e

if done:
_, env_tools = await env.reset()

save_environment(env, env_tools, environment_id)
return {
"result": "\n\n".join([
str(msg.content) for msg in result_msgs if msg.content
]),
"environment_id": environment_id,
}

_tool_handler.__doc__ = tool_description
return _tool_handler

tool_handler = create_tool_handler(
tool.info.name, RequestModel, tool_description
)

# Add a POST route so we can invoke the tool function
web_app.post(
f"/{tool_name}",
summary=tool_name,
name=tool_name,
description=tool_description,
)(tool_handler)

# Add environment endpoints
@web_app.get(
"/env/create",
summary="Create Environment",
description="Create a new environment",
)
async def create_environment_endpoint():
env = environment_factory()
_, tools = await env.reset()
environment_id = make_environment_id()
save_environment(env, tools, environment_id)
return environment_id

@web_app.get(
"/env/delete/{environment_id}",
summary="Delete Environment",
description="Delete an environment",
)
async def delete_environment_endpoint(environment_id: str):
if (env_path / f"{environment_id}.pkl").exists():
(env_path / f"{environment_id}.pkl").unlink()
return environment_id

@web_app.get(
"/env/view/{environment_id}",
summary="View Environment",
description="View an environment",
)
async def view_environment_endpoint(environment_id: str):
if not (env_path / f"{environment_id}.pkl").exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Environment {environment_id} not found",
)
with (env_path / f"{environment_id}.pkl").open("rb") as f:
env, _ = pickle.load(f)

return env.state

return web_app
Loading
Loading