Skip to content

Commit

Permalink
refactor(api): improve handling of tools field and cleanup variable…
Browse files Browse the repository at this point in the history
… usage (langgenius#10553)
  • Loading branch information
laipz8200 authored Nov 11, 2024
1 parent b7238ca commit 16b9665
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
9 changes: 7 additions & 2 deletions api/core/tools/entities/api_entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal, Optional

from pydantic import BaseModel
from pydantic import BaseModel, Field, field_validator

from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.common_entities import I18nObject
Expand Down Expand Up @@ -32,9 +32,14 @@ class UserToolProvider(BaseModel):
original_credentials: Optional[dict] = None
is_team_authorization: bool = False
allow_delete: bool = True
tools: list[UserTool] | None = None
tools: list[UserTool] = Field(default_factory=list)
labels: list[str] | None = None

@field_validator("tools", mode="before")
@classmethod
def convert_none_to_empty_list(cls, v):
return v if v is not None else []

def to_dict(self) -> dict:
# -------------
# overwrite tool parameter types for temp fix
Expand Down
15 changes: 7 additions & 8 deletions api/services/tools/api_tools_manage_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def create_api_tool_provider(
provider_name = provider_name.strip()

# check if the provider exists
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
Expand Down Expand Up @@ -201,16 +201,15 @@ def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
return {"schema": schema}

@staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
"""
list api tool provider tools
"""
provider_name = provider
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
ApiToolProvider.name == provider_name,
)
.first()
)
Expand Down Expand Up @@ -252,7 +251,7 @@ def update_api_tool_provider(
provider_name = provider_name.strip()

# check if the provider exists
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
Expand Down Expand Up @@ -319,7 +318,7 @@ def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
"""
delete tool provider
"""
provider: ApiToolProvider = (
provider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
Expand Down Expand Up @@ -369,7 +368,7 @@ def test_api_tool_preview(
if tool_bundle is None:
raise ValueError(f"invalid tool name {tool_name}")

db_provider: ApiToolProvider = (
db_provider = (
db.session.query(ApiToolProvider)
.filter(
ApiToolProvider.tenant_id == tenant_id,
Expand Down

0 comments on commit 16b9665

Please sign in to comment.