Skip to content

Commit

Permalink
feat: Sandboxing for tool execution (#2040)
Browse files Browse the repository at this point in the history
Co-authored-by: Caren Thomas <carenthomas@Jeffs-MacBook-Pro-2.local>
Co-authored-by: Caren Thomas <carenthomas@jeffs-mbp-2.lan>
Co-authored-by: Caren Thomas <carenthomas@Jeffs-MBP-2.hsd1.ca.comcast.net>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
  • Loading branch information
5 people authored Nov 22, 2024
1 parent 476541f commit ae083fc
Show file tree
Hide file tree
Showing 39 changed files with 2,843 additions and 862 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/docker-integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
pipx install poetry==1.8.2
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client.py
poetry run pytest -s tests/test_client_legacy.py
- name: Print docker logs if tests fail
if: failure()
Expand Down
13 changes: 7 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
run-core-unit-tests:
Expand All @@ -21,14 +21,15 @@ jobs:
fail-fast: false
matrix:
test_suite:
- "test_local_client.py"
- "test_client.py"
- "test_local_client.py"
- "test_client_legacy.py"
- "test_server.py"
- "test_managers.py"
- "test_tools.py"
- "test_o1_agent.py"
- "test_tool_rule_solver.py"
- "test_agent_tool_graph.py"
- "test_tool_execution_sandbox.py"
- "test_utils.py"
services:
qdrant:
Expand Down Expand Up @@ -58,7 +59,7 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database
env:
LETTA_PG_PORT: 5432
Expand Down Expand Up @@ -111,7 +112,7 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database
env:
LETTA_PG_PORT: 5432
Expand All @@ -132,4 +133,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_utils.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
poetry run pytest -s -vv -k "not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
5 changes: 0 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -1018,8 +1018,3 @@ pgdata/
letta/.pytest_cache/
memgpy/pytest.ini
**/**/pytest_cache


# local sandbox venvs
letta/services/tool_sandbox_env/*
tests/test_tool_sandbox/*
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Create sandbox config and sandbox env var tables
Revision ID: f81ceea2c08d
Revises: f7507eab4bb9
Create Date: 2024-11-14 17:51:27.263561
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
# `revision` is this migration's own id; `down_revision` is the id of the
# migration this one is applied on top of.
revision: str = "f81ceea2c08d"
down_revision: Union[str, None] = "f7507eab4bb9"
# No branch label and no cross-branch dependency are declared for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Create the ``sandbox_configs`` and ``sandbox_environment_variables`` tables.

    ``sandbox_configs`` holds one JSON config per (type, organization_id) pair,
    enforced by the ``uix_type_organization`` unique constraint.
    ``sandbox_environment_variables`` holds key/value rows that reference a
    sandbox config, with ``key`` unique per ``sandbox_config_id``
    (``uix_key_sandbox_config``).
    """

    def bookkeeping_columns() -> list:
        # Built fresh on every call: a Column object can only belong to one table,
        # so the two tables cannot share the same instances.
        return [
            sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
            sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
            sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
            sa.Column("_created_by_id", sa.String(), nullable=True),
            sa.Column("_last_updated_by_id", sa.String(), nullable=True),
            sa.Column("organization_id", sa.String(), nullable=False),
        ]

    def organization_fk() -> sa.ForeignKeyConstraint:
        # Both tables point at the owning organization.
        return sa.ForeignKeyConstraint(["organization_id"], ["organizations.id"])

    # Parent table: one sandbox configuration per sandbox type and organization.
    sandbox_config_columns = [
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("type", sa.Enum("E2B", "LOCAL", name="sandboxtype"), nullable=False),
        sa.Column("config", sa.JSON(), nullable=False),
    ]
    op.create_table(
        "sandbox_configs",
        *sandbox_config_columns,
        *bookkeeping_columns(),
        organization_fk(),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("type", "organization_id", name="uix_type_organization"),
    )

    # Child table: environment variables attached to a sandbox configuration.
    env_var_columns = [
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=True),
    ]
    op.create_table(
        "sandbox_environment_variables",
        *env_var_columns,
        *bookkeeping_columns(),
        sa.Column("sandbox_config_id", sa.String(), nullable=False),
        organization_fk(),
        sa.ForeignKeyConstraint(["sandbox_config_id"], ["sandbox_configs.id"]),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("key", "sandbox_config_id", name="uix_key_sandbox_config"),
    )


def downgrade() -> None:
    """Undo this revision: drop both sandbox tables and the enum type they used.

    The tables are removed in reverse order of creation because
    ``sandbox_environment_variables`` has a foreign key to ``sandbox_configs.id``.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table("sandbox_environment_variables")
    op.drop_table("sandbox_configs")
    # ### end Alembic commands ###

    # op.drop_table() leaves behind the named enum type that upgrade() created
    # for sandbox_configs.type on PostgreSQL. Without this, re-running upgrade()
    # after a downgrade fails because the "sandboxtype" type already exists.
    # checkfirst=True makes this safe when the type is already gone.
    sa.Enum(name="sandboxtype").drop(op.get_bind(), checkfirst=True)
2 changes: 1 addition & 1 deletion examples/docs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


# define a function with a docstring
def roll_d20(self) -> str:
def roll_d20() -> str:
"""
Simulate the roll of a 20-sided die (d20).
Expand Down
10 changes: 5 additions & 5 deletions examples/tool_rule_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
"""Contrived tools for this test case"""


def first_secret_word(self: "Agent"):
def first_secret_word():
"""
Call this to retrieve the first secret word, which you will need for the second_secret_word function.
"""
return "v0iq020i0g"


def second_secret_word(self: "Agent", prev_secret_word: str):
def second_secret_word(prev_secret_word: str):
"""
Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.
Expand All @@ -51,7 +51,7 @@ def second_secret_word(self: "Agent", prev_secret_word: str):
return "4rwp2b4gxq"


def third_secret_word(self: "Agent", prev_secret_word: str):
def third_secret_word(prev_secret_word: str):
"""
Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.
Expand All @@ -64,7 +64,7 @@ def third_secret_word(self: "Agent", prev_secret_word: str):
return "hj2hwibbqm"


def fourth_secret_word(self: "Agent", prev_secret_word: str):
def fourth_secret_word(prev_secret_word: str):
"""
Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.
Expand All @@ -77,7 +77,7 @@ def fourth_secret_word(self: "Agent", prev_secret_word: str):
return "banana"


def auto_error(self: "Agent"):
def auto_error():
"""
If you call this function, it will throw an error automatically.
"""
Expand Down
25 changes: 23 additions & 2 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from letta.agent_store.storage import StorageConnector
from letta.constants import (
BASE_TOOLS,
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
Expand Down Expand Up @@ -49,6 +50,7 @@
from letta.schemas.usage import LettaUsageStatistics
from letta.services.block_manager import BlockManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.user_manager import UserManager
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
Expand Down Expand Up @@ -725,9 +727,27 @@ def _handle_ai_response(
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

function_args["self"] = self # need to attach self to arg since it's dynamically linked
# TODO: This needs to be rethought, how do we allow functions that modify agent state/db?
# TODO: There should probably be two types of tools: stateless/stateful

if function_name in BASE_TOOLS:
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
else:
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
agent_state=self.agent_state
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
# update agent state
if self.agent_state != updated_agent_state and updated_agent_state is not None:
self.agent_state = updated_agent_state
self.memory = self.agent_state.memory # TODO: don't duplicate

# rebuild memory
self.rebuild_memory()

function_response = function_to_call(**function_args)
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
# with certain functions we rely on the paging mechanism to handle overflow
truncate = False
Expand All @@ -747,6 +767,7 @@ def _handle_ai_response(
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
Expand Down
Loading

0 comments on commit ae083fc

Please sign in to comment.