From 18c62a0c24e848fc4bf64aca2351170a19dfd677 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 20 Sep 2024 16:12:52 -0700 Subject: [PATCH] Add additional custom tooling configuration (#2426) * add custom headers * add tool seeding * squash * tmep * validated * rm * update typing * update alembic * update import name * reformat * alembic --- ...32615f71aeb_add_custom_headers_to_tools.py | 26 +++++ backend/danswer/chat/process_message.py | 7 +- backend/danswer/db/models.py | 4 +- backend/danswer/db/tools.py | 8 ++ backend/danswer/server/features/tool/api.py | 16 +-- .../danswer/server/features/tool/models.py | 21 ++++ backend/danswer/tools/custom/custom_tool.py | 21 +++- .../ee/danswer/server/query_and_chat/utils.py | 6 +- backend/ee/danswer/server/seeding.py | 54 +++++++++++ web/src/app/admin/tools/ToolEditor.tsx | 97 ++++++++++++++++++- web/src/lib/tools/edit.ts | 3 + web/src/lib/tools/interfaces.ts | 4 + 12 files changed, 242 insertions(+), 25 deletions(-) create mode 100644 backend/alembic/versions/f32615f71aeb_add_custom_headers_to_tools.py diff --git a/backend/alembic/versions/f32615f71aeb_add_custom_headers_to_tools.py b/backend/alembic/versions/f32615f71aeb_add_custom_headers_to_tools.py new file mode 100644 index 00000000000..904059e6ee3 --- /dev/null +++ b/backend/alembic/versions/f32615f71aeb_add_custom_headers_to_tools.py @@ -0,0 +1,26 @@ +"""add custom headers to tools + +Revision ID: f32615f71aeb +Revises: bd2921608c3a +Create Date: 2024-09-12 20:26:38.932377 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "f32615f71aeb" +down_revision = "bd2921608c3a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True) + ) + + +def downgrade() -> None: + op.drop_column("tool", "custom_headers") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 2e79e200676..f09ac18f32a 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -73,7 +73,9 @@ from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema +from danswer.tools.custom.custom_tool import ( + build_custom_tools_from_openapi_schema_and_headers, +) from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID from danswer.tools.custom.custom_tool import CustomToolCallSummary from danswer.tools.force import ForceUseTool @@ -607,12 +609,13 @@ def stream_chat_message_objects( if db_tool_model.openapi_schema: tool_dict[db_tool_model.id] = cast( list[Tool], - build_custom_tools_from_openapi_schema( + build_custom_tools_from_openapi_schema_and_headers( db_tool_model.openapi_schema, dynamic_schema_info=DynamicSchemaInfo( chat_session_id=chat_session_id, message_id=user_message.id if user_message else None, ), + custom_headers=db_tool_model.custom_headers, ), ) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index aecb1d9f973..f5be97d1bbe 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -1255,7 +1255,9 @@ class Tool(Base): openapi_schema: Mapped[dict[str, Any] | None] = mapped_column( postgresql.JSONB(), nullable=True ) - + custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) # user who created / owns the tool. Will be None for built-in tools. user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=True diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py index 1e75b1c4901..248744b5639 100644 --- a/backend/danswer/db/tools.py +++ b/backend/danswer/db/tools.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from danswer.db.models import Tool +from danswer.server.features.tool.models import Header from danswer.utils.logger import setup_logger logger = setup_logger() @@ -25,6 +26,7 @@ def create_tool( name: str, description: str | None, openapi_schema: dict[str, Any] | None, + custom_headers: list[Header] | None, user_id: UUID | None, db_session: Session, ) -> Tool: @@ -33,6 +35,9 @@ def create_tool( description=description, in_code_tool_id=None, openapi_schema=openapi_schema, + custom_headers=[header.dict() for header in custom_headers] + if custom_headers + else [], user_id=user_id, ) db_session.add(new_tool) @@ -45,6 +50,7 @@ def update_tool( name: str | None, description: str | None, openapi_schema: dict[str, Any] | None, + custom_headers: list[Header] | None, user_id: UUID | None, db_session: Session, ) -> Tool: @@ -60,6 +66,8 @@ def update_tool( tool.openapi_schema = openapi_schema if user_id is not None: tool.user_id = user_id + if custom_headers is not None: + tool.custom_headers = [header.dict() for header in custom_headers] db_session.commit() return tool diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 9635a276507..1d441593784 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -15,6 +15,8 @@ from danswer.db.tools import get_tool_by_id from danswer.db.tools import get_tools from danswer.db.tools import update_tool +from danswer.server.features.tool.models import CustomToolCreate +from danswer.server.features.tool.models import CustomToolUpdate from danswer.server.features.tool.models import ToolSnapshot from danswer.tools.custom.openapi_parsing import MethodSpec from danswer.tools.custom.openapi_parsing import openapi_to_method_specs @@ -24,18 +26,6 @@ admin_router = APIRouter(prefix="/admin/tool") -class CustomToolCreate(BaseModel): - name: str - description: str | None = None - definition: dict[str, Any] - - -class CustomToolUpdate(BaseModel): - name: str | None = None - description: str | None = None - definition: dict[str, Any] | None = None - - def _validate_tool_definition(definition: dict[str, Any]) -> None: try: validate_openapi_schema(definition) @@ -54,6 +44,7 @@ def create_custom_tool( name=tool_data.name, description=tool_data.description, openapi_schema=tool_data.definition, + custom_headers=tool_data.custom_headers, user_id=user.id if user else None, db_session=db_session, ) @@ -74,6 +65,7 @@ def update_custom_tool( name=tool_data.name, description=tool_data.description, openapi_schema=tool_data.definition, + custom_headers=tool_data.custom_headers, user_id=user.id if user else None, db_session=db_session, ) diff --git a/backend/danswer/server/features/tool/models.py b/backend/danswer/server/features/tool/models.py index 0c1da965d4f..bf3e4d159b6 100644 --- a/backend/danswer/server/features/tool/models.py +++ b/backend/danswer/server/features/tool/models.py @@ -12,6 +12,7 @@ class ToolSnapshot(BaseModel): definition: dict[str, Any] | None display_name: str in_code_tool_id: str | None + custom_headers: list[Any] | None @classmethod def from_model(cls, tool: Tool) -> "ToolSnapshot": @@ -22,4 +23,24 @@ def from_model(cls, tool: Tool) -> "ToolSnapshot": definition=tool.openapi_schema, display_name=tool.display_name or tool.name, in_code_tool_id=tool.in_code_tool_id, + custom_headers=tool.custom_headers, ) + + +class Header(BaseModel): + key: str + value: str + + +class CustomToolCreate(BaseModel): + name: str + description: str | None = None + definition: dict[str, Any] + custom_headers: list[Header] | None = None + + +class CustomToolUpdate(BaseModel): + name: str | None = None + description: str | None = None + definition: dict[str, Any] | None = None + custom_headers: list[Header] | None = None diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py index 0272b4ad607..3d36d7bb055 100644 --- a/backend/danswer/tools/custom/custom_tool.py +++ b/backend/danswer/tools/custom/custom_tool.py @@ -46,6 +46,7 @@ def __init__( self, method_spec: MethodSpec, base_url: str, + custom_headers: list[dict[str, str]] | None = [], ) -> None: self._base_url = base_url self._method_spec = method_spec @@ -53,6 +54,11 @@ def __init__( self._name = self._method_spec.name self._description = self._method_spec.summary + self.headers = ( + {header["key"]: header["value"] for header in custom_headers} + if custom_headers + else {} + ) @property def name(self) -> str: @@ -161,8 +167,10 @@ def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: url = self._method_spec.build_url(self._base_url, path_params, query_params) method = self._method_spec.method - - response = requests.request(method, url, json=request_body) + # Log request details + response = requests.request( + method, url, json=request_body, headers=self.headers + ) yield ToolResponse( id=CUSTOM_TOOL_RESPONSE_ID, @@ -175,8 +183,9 @@ def final_result(self, *args: ToolResponse) -> JSON_ro: return cast(CustomToolCallSummary, args[0].response).tool_result -def build_custom_tools_from_openapi_schema( +def build_custom_tools_from_openapi_schema_and_headers( openapi_schema: dict[str, Any], + custom_headers: list[dict[str, str]] | None = [], dynamic_schema_info: DynamicSchemaInfo | None = None, ) -> list[CustomTool]: if dynamic_schema_info: @@ -195,7 +204,9 @@ def build_custom_tools_from_openapi_schema( url = openapi_to_url(openapi_schema) method_specs = openapi_to_method_specs(openapi_schema) - return [CustomTool(method_spec, url) for method_spec in method_specs] + return [ + CustomTool(method_spec, url, custom_headers) for method_spec in method_specs + ] if __name__ == "__main__": @@ -246,7 +257,7 @@ def build_custom_tools_from_openapi_schema( } validate_openapi_schema(openapi_schema) - tools = build_custom_tools_from_openapi_schema( + tools = build_custom_tools_from_openapi_schema_and_headers( openapi_schema, dynamic_schema_info=None ) diff --git a/backend/ee/danswer/server/query_and_chat/utils.py b/backend/ee/danswer/server/query_and_chat/utils.py index beb970fd1b8..a2f7253517a 100644 --- a/backend/ee/danswer/server/query_and_chat/utils.py +++ b/backend/ee/danswer/server/query_and_chat/utils.py @@ -12,7 +12,9 @@ from danswer.db.models import User from danswer.db.persona import get_prompts_by_ids from danswer.one_shot_answer.models import PersonaConfig -from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema +from danswer.tools.custom.custom_tool import ( + build_custom_tools_from_openapi_schema_and_headers, +) def create_temporary_persona( @@ -58,7 +60,7 @@ def create_temporary_persona( for schema in persona_config.custom_tools_openapi: tools = cast( list[Tool], - build_custom_tools_from_openapi_schema(schema), + build_custom_tools_from_openapi_schema_and_headers(schema), ) persona.tools.extend(tools) diff --git a/backend/ee/danswer/server/seeding.py b/backend/ee/danswer/server/seeding.py index b161f057030..007aa352cae 100644 --- a/backend/ee/danswer/server/seeding.py +++ b/backend/ee/danswer/server/seeding.py @@ -1,4 +1,7 @@ +import json import os +from typing import List +from typing import Optional from pydantic import BaseModel from sqlalchemy.orm import Session @@ -6,6 +9,7 @@ from danswer.db.engine import get_session_context_manager from danswer.db.llm import update_default_provider from danswer.db.llm import upsert_llm_provider +from danswer.db.models import Tool from danswer.db.persona import upsert_persona from danswer.search.enums import RecencyBiasSetting from danswer.server.features.persona.models import CreatePersonaRequest @@ -25,6 +29,16 @@ from ee.danswer.server.enterprise_settings.store import upload_logo +class CustomToolSeed(BaseModel): + name: str + description: str + definition_path: str + custom_headers: Optional[List[dict]] = None + display_name: Optional[str] = None + in_code_tool_id: Optional[str] = None + user_id: Optional[str] = None + + logger = setup_logger() _SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" @@ -39,6 +53,7 @@ class SeedConfiguration(BaseModel): enterprise_settings: EnterpriseSettings | None = None # Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference analytics_script_path: str | None = None + custom_tools: List[CustomToolSeed] | None = None def _parse_env() -> SeedConfiguration | None: @@ -49,6 +64,43 @@ def _parse_env() -> SeedConfiguration | None: return seed_config +def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None: + if tools: + logger.notice("Seeding Custom Tools") + for tool in tools: + try: + logger.debug(f"Attempting to seed tool: {tool.name}") + logger.debug(f"Reading definition from: {tool.definition_path}") + with open(tool.definition_path, "r") as file: + file_content = file.read() + if not file_content.strip(): + raise ValueError("File is empty") + openapi_schema = json.loads(file_content) + db_tool = Tool( + name=tool.name, + description=tool.description, + openapi_schema=openapi_schema, + custom_headers=tool.custom_headers, + display_name=tool.display_name, + in_code_tool_id=tool.in_code_tool_id, + user_id=tool.user_id, + ) + db_session.add(db_tool) + logger.debug(f"Successfully added tool: {tool.name}") + except FileNotFoundError: + logger.error( + f"Definition file not found for tool {tool.name}: {tool.definition_path}" + ) + except json.JSONDecodeError as e: + logger.error( + f"Invalid JSON in definition file for tool {tool.name}: {str(e)}" + ) + except Exception as e: + logger.error(f"Failed to seed tool {tool.name}: {str(e)}") + db_session.commit() + logger.notice(f"Successfully seeded {len(tools)} Custom Tools") + + def _seed_llms( db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest] ) -> None: @@ -147,6 +199,8 @@ def seed_db() -> None: _seed_personas(db_session, seed_config.personas) if seed_config.settings is not None: _seed_settings(seed_config.settings) + if seed_config.custom_tools is not None: + _seed_custom_tools(db_session, seed_config.custom_tools) _seed_logo(db_session, seed_config.seeded_logo_path) _seed_enterprise_settings(seed_config) diff --git a/web/src/app/admin/tools/ToolEditor.tsx b/web/src/app/admin/tools/ToolEditor.tsx index ff8bd86c40e..b4df98f8623 100644 --- a/web/src/app/admin/tools/ToolEditor.tsx +++ b/web/src/app/admin/tools/ToolEditor.tsx @@ -2,7 +2,14 @@ import { useState, useEffect, useCallback } from "react"; import { useRouter } from "next/navigation"; -import { Formik, Form, Field, ErrorMessage } from "formik"; +import { + Formik, + Form, + Field, + ErrorMessage, + FieldArray, + ArrayHelpers, +} from "formik"; import * as Yup from "yup"; import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces"; import { TextFormField } from "@/components/admin/connectors/Field"; @@ -14,6 +21,7 @@ import { } from "@/lib/tools/edit"; import { usePopup } from "@/components/admin/connectors/Popup"; import debounce from "lodash/debounce"; +import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle"; import Link from "next/link"; function parseJsonWithTrailingCommas(jsonString: string) { @@ -55,6 +63,7 @@ function ToolForm({ }) { const [definitionError, setDefinitionError] = definitionErrorState; const [methodSpecs, setMethodSpecs] = methodSpecsState; + const [showAdvancedOptions, setShowAdvancedOptions] = useState(false); const debouncedValidateDefinition = useCallback( debounce(async (definition: string) => { @@ -137,7 +146,7 @@ function ToolForm({
{methodSpecs && methodSpecs.length > 0 && ( -
+

Available methods

@@ -192,7 +201,75 @@ function ToolForm({ )} + + {showAdvancedOptions && ( +
+

+ Custom Headers +

+

+ Specify custom headers for each request to this tool's API. +

+ ( +
+ {values.customHeaders && values.customHeaders.length > 0 && ( +
+ {values.customHeaders.map( + ( + header: { key: string; value: string }, + index: number + ) => ( +
+ + + +
+ ) + )} +
+ )} + + +
+ )} + /> +
+ )} + +