Skip to content

Commit

Permalink
Add additional custom tooling configuration (#2426)
Browse files Browse the repository at this point in the history
* add custom headers

* add tool seeding

* squash

* tmep

* validated

* rm

* update typing

* update alembic

* update import name

* reformat
  • Loading branch information
pablonyx authored Sep 19, 2024
1 parent 824a737 commit f598081
Show file tree
Hide file tree
Showing 12 changed files with 242 additions and 25 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: 35e6853a51d5
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "35e6853a51d5"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)


def downgrade() -> None:
op.drop_column("tool", "custom_headers")
7 changes: 5 additions & 2 deletions backend/danswer/chat/process_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
Expand Down Expand Up @@ -605,12 +607,13 @@ def stream_chat_message_objects(
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
),
)

Expand Down
4 changes: 3 additions & 1 deletion backend/danswer/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,9 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)

custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
Expand Down
8 changes: 8 additions & 0 deletions backend/danswer/db/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlalchemy.orm import Session

from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.logger import setup_logger

logger = setup_logger()
Expand All @@ -25,6 +26,7 @@ def create_tool(
name: str,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
Expand All @@ -33,6 +35,9 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,
)
db_session.add(new_tool)
Expand All @@ -45,6 +50,7 @@ def update_tool(
name: str | None,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
Expand All @@ -60,6 +66,8 @@ def update_tool(
tool.openapi_schema = openapi_schema
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
db_session.commit()

return tool
Expand Down
16 changes: 4 additions & 12 deletions backend/danswer/server/features/tool/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from danswer.db.tools import get_tool_by_id
from danswer.db.tools import get_tools
from danswer.db.tools import update_tool
from danswer.server.features.tool.models import CustomToolCreate
from danswer.server.features.tool.models import CustomToolUpdate
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
Expand All @@ -24,18 +26,6 @@
admin_router = APIRouter(prefix="/admin/tool")


class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]


class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None


def _validate_tool_definition(definition: dict[str, Any]) -> None:
try:
validate_openapi_schema(definition)
Expand All @@ -54,6 +44,7 @@ def create_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)
Expand All @@ -74,6 +65,7 @@ def update_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)
Expand Down
21 changes: 21 additions & 0 deletions backend/danswer/server/features/tool/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ToolSnapshot(BaseModel):
definition: dict[str, Any] | None
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None

@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
Expand All @@ -22,4 +23,24 @@ def from_model(cls, tool: Tool) -> "ToolSnapshot":
definition=tool.openapi_schema,
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
)


class Header(BaseModel):
key: str
value: str


class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None


class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None
21 changes: 16 additions & 5 deletions backend/danswer/tools/custom/custom_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,19 @@ def __init__(
self,
method_spec: MethodSpec,
base_url: str,
custom_headers: list[dict[str, str]] | None = [],
) -> None:
self._base_url = base_url
self._method_spec = method_spec
self._tool_definition = self._method_spec.to_tool_definition()

self._name = self._method_spec.name
self._description = self._method_spec.summary
self.headers = (
{header["key"]: header["value"] for header in custom_headers}
if custom_headers
else {}
)

@property
def name(self) -> str:
Expand Down Expand Up @@ -161,8 +167,10 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:

url = self._method_spec.build_url(self._base_url, path_params, query_params)
method = self._method_spec.method

response = requests.request(method, url, json=request_body)
# Log request details
response = requests.request(
method, url, json=request_body, headers=self.headers
)

yield ToolResponse(
id=CUSTOM_TOOL_RESPONSE_ID,
Expand All @@ -175,8 +183,9 @@ def final_result(self, *args: ToolResponse) -> JSON_ro:
return cast(CustomToolCallSummary, args[0].response).tool_result


def build_custom_tools_from_openapi_schema(
def build_custom_tools_from_openapi_schema_and_headers(
openapi_schema: dict[str, Any],
custom_headers: list[dict[str, str]] | None = [],
dynamic_schema_info: DynamicSchemaInfo | None = None,
) -> list[CustomTool]:
if dynamic_schema_info:
Expand All @@ -195,7 +204,9 @@ def build_custom_tools_from_openapi_schema(

url = openapi_to_url(openapi_schema)
method_specs = openapi_to_method_specs(openapi_schema)
return [CustomTool(method_spec, url) for method_spec in method_specs]
return [
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
]


if __name__ == "__main__":
Expand Down Expand Up @@ -246,7 +257,7 @@ def build_custom_tools_from_openapi_schema(
}
validate_openapi_schema(openapi_schema)

tools = build_custom_tools_from_openapi_schema(
tools = build_custom_tools_from_openapi_schema_and_headers(
openapi_schema, dynamic_schema_info=None
)

Expand Down
6 changes: 4 additions & 2 deletions backend/ee/danswer/server/query_and_chat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.one_shot_answer.models import PersonaConfig
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)


def create_temporary_persona(
Expand Down Expand Up @@ -58,7 +60,7 @@ def create_temporary_persona(
for schema in persona_config.custom_tools_openapi:
tools = cast(
list[Tool],
build_custom_tools_from_openapi_schema(schema),
build_custom_tools_from_openapi_schema_and_headers(schema),
)
persona.tools.extend(tools)

Expand Down
54 changes: 54 additions & 0 deletions backend/ee/danswer/server/seeding.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import json
import os
from typing import List
from typing import Optional

from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.db.engine import get_session_context_manager
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import Tool
from danswer.db.persona import upsert_persona
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
Expand All @@ -25,6 +29,16 @@
from ee.danswer.server.enterprise_settings.store import upload_logo


class CustomToolSeed(BaseModel):
name: str
description: str
definition_path: str
custom_headers: Optional[List[dict]] = None
display_name: Optional[str] = None
in_code_tool_id: Optional[str] = None
user_id: Optional[str] = None


logger = setup_logger()

_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
Expand All @@ -39,6 +53,7 @@ class SeedConfiguration(BaseModel):
enterprise_settings: EnterpriseSettings | None = None
# Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference
analytics_script_path: str | None = None
custom_tools: List[CustomToolSeed] | None = None


def _parse_env() -> SeedConfiguration | None:
Expand All @@ -49,6 +64,43 @@ def _parse_env() -> SeedConfiguration | None:
return seed_config


def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None:
if tools:
logger.notice("Seeding Custom Tools")
for tool in tools:
try:
logger.debug(f"Attempting to seed tool: {tool.name}")
logger.debug(f"Reading definition from: {tool.definition_path}")
with open(tool.definition_path, "r") as file:
file_content = file.read()
if not file_content.strip():
raise ValueError("File is empty")
openapi_schema = json.loads(file_content)
db_tool = Tool(
name=tool.name,
description=tool.description,
openapi_schema=openapi_schema,
custom_headers=tool.custom_headers,
display_name=tool.display_name,
in_code_tool_id=tool.in_code_tool_id,
user_id=tool.user_id,
)
db_session.add(db_tool)
logger.debug(f"Successfully added tool: {tool.name}")
except FileNotFoundError:
logger.error(
f"Definition file not found for tool {tool.name}: {tool.definition_path}"
)
except json.JSONDecodeError as e:
logger.error(
f"Invalid JSON in definition file for tool {tool.name}: {str(e)}"
)
except Exception as e:
logger.error(f"Failed to seed tool {tool.name}: {str(e)}")
db_session.commit()
logger.notice(f"Successfully seeded {len(tools)} Custom Tools")


def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
Expand Down Expand Up @@ -146,6 +198,8 @@ def seed_db() -> None:
_seed_personas(db_session, seed_config.personas)
if seed_config.settings is not None:
_seed_settings(seed_config.settings)
if seed_config.custom_tools is not None:
_seed_custom_tools(db_session, seed_config.custom_tools)

_seed_logo(db_session, seed_config.seeded_logo_path)
_seed_enterprise_settings(seed_config)
Expand Down
Loading

0 comments on commit f598081

Please sign in to comment.