Skip to content

Commit

Permalink
google drive tool: new endpoint for google drive tool auth deletion (#…
Browse files Browse the repository at this point in the history
…654)

* Initial commit - scaffolded new tool auth deletion route

* formatting

* corrected error in function param order, cleaned up logs

* PR Feedback: moved tool auth deletion to base auth implementation class, moved tool_id to path parameter

* Add new endpoint to postman collection

* PR feedback: log_and_raise_exception to logger util, return empty DeleteToolAuth response

* PR Feedback: error_and_raise_http_exception to logger class

* raise bare exception to preserve stacktrace

* added crud tests for tool_auth
  • Loading branch information
chantelle-cohere authored Aug 15, 2024
1 parent 14c77f2 commit b5f2991
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 1 deletion.
20 changes: 20 additions & 0 deletions docs/postman/Toolkit.postman_collection.json
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,26 @@
}
]
},
{
"name": "Tool Auth",
"item": [
{
"name": "Delete Tool Auth",
"request": {
"auth": {
"type": "bearer",
"bearer": {
"token": "{{auth_token}}"
}
},
"method": "DELETE",
"header": [],
"url": "http://localhost:8000/v1/tool/auth/{{tool_id}}"
},
"response": []
}
]
},
{
"name": "Health",
"protocolProfileBehavior": {
Expand Down
66 changes: 65 additions & 1 deletion src/backend/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
from backend.config.auth import ENABLED_AUTH_STRATEGY_MAPPING
from backend.config.routers import RouterName
from backend.config.settings import Settings
from backend.config.tools import AVAILABLE_TOOLS
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.crud import blacklist as blacklist_crud
from backend.database_models import Blacklist
from backend.database_models.database import DBSessionDep
from backend.schemas.auth import JWTResponse, ListAuthStrategy, Login, Logout
from backend.schemas.context import Context
from backend.schemas.tool_auth import DeleteToolAuth
from backend.services.auth.jwt import JWTService
from backend.services.auth.request_validators import validate_authorization
from backend.services.auth.utils import (
Expand Down Expand Up @@ -317,3 +318,66 @@ def log_and_redirect_err(error_message: str):
response = RedirectResponse(redirect_uri)

return response


@router.delete("/tool/auth/{tool_id}")
async def delete_tool_auth(
tool_id: str,
request: Request,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> DeleteToolAuth:
"""
Endpoint to delete Tool Authentication.
If completed, the corresponding ToolAuth for the requesting user is removed from the DB.
Args:
tool_id (str): Tool ID to be deleted for the user. (eg. google_drive) Should be one of the values listed in the ToolName string enum class.
request (Request): current Request object.
session (DBSessionDep): Database session.
ctx (Context): Context object.
Returns:
DeleteToolAuth: Empty response.
Raises:
HTTPException: If there was an error deleting the tool auth.
"""

logger = ctx.get_logger()

user_id = ctx.get_user_id()
tool_id = tool_id.lower()

if user_id is None or user_id == "" or user_id == "default":
logger.error_and_raise_http_exception(event="User ID not found.")

if tool_id not in [tool_name.value for tool_name in ToolName]:
logger.error_and_raise_http_exception(
event="tool_id must be present in the path of the request and must be a member of the ToolName string enum class.",
)

tool = AVAILABLE_TOOLS.get(tool_id)

if tool is None:
logger.error_and_raise_http_exception(
event=f"Tool {tool_id} is not available in AVAILABLE_TOOLS."
)

if tool.auth_implementation is None:
logger.error_and_raise_http_exception(
event=f"Tool {tool.name} does not have an auth_implementation required for Tool Auth Deletion.",
)

try:
tool_auth_service = tool.auth_implementation()
is_deleted = tool_auth_service.delete_tool_auth(session, user_id)

if not is_deleted:
logger.error_and_raise_http_exception(event="Error deleting Tool Auth.")

except Exception as e:
logger.error_and_raise_http_exception(event=str(e))

return DeleteToolAuth()
4 changes: 4 additions & 0 deletions src/backend/schemas/tool_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@ class UpdateToolAuth(BaseModel):
class Config:
from_attributes = True
use_enum_values = True


class DeleteToolAuth(BaseModel):
pass
3 changes: 3 additions & 0 deletions src/backend/services/logger/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def info(self, **kwargs: Any) -> Any: ...
@abstractmethod
def error(self, **kwargs: Any) -> Any: ...

@abstractmethod
def error_and_raise_http_exception(self, **kwargs: Any) -> Any: ...

@abstractmethod
def warning(self, **kwargs: Any) -> Any: ...

Expand Down
9 changes: 9 additions & 0 deletions src/backend/services/logger/strategies/structured_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict

import structlog
from fastapi import HTTPException, status

from backend.services.logger.strategies.base import BaseLogger

Expand Down Expand Up @@ -84,6 +85,14 @@ def info(self, **kwargs):
def error(self, **kwargs):
self.logger.error(**kwargs)

@log_context
def error_and_raise_http_exception(self, **kwargs):
self.logger.error(**kwargs)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"{kwargs}",
)

@log_context
def warning(self, **kwargs):
self.logger.warning(**kwargs)
Expand Down
52 changes: 52 additions & 0 deletions src/backend/tests/unit/crud/test_tool_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from datetime import datetime

import pytest

from backend.config.tools import ToolName
from backend.crud import tool_auth as tool_auth_crud
from backend.database_models.tool_auth import ToolAuth
from backend.tests.unit.factories import get_factory


def test_create_tool_auth(session, user):

tool_auth_data = ToolAuth(
user_id=user.id,
tool_id=ToolName.Google_Drive,
token_type="Bearer",
encrypted_access_token=bytes(b"foobar"),
encrypted_refresh_token=bytes(b"foobar"),
expires_at=datetime.strptime("12/31/2124 00:00:00", "%m/%d/%Y %H:%M:%S"),
created_at=datetime.strptime("01/01/2000 00:00:00", "%m/%d/%Y %H:%M:%S"),
updated_at=datetime.strptime("01/01/2010 00:00:00", "%m/%d/%Y %H:%M:%S"),
)

tool_auth = tool_auth_crud.create_tool_auth(session, tool_auth_data)

assert tool_auth.user_id == tool_auth_data.user_id
assert tool_auth.tool_id == tool_auth_data.tool_id
assert tool_auth.token_type == tool_auth_data.token_type
assert tool_auth.encrypted_access_token == tool_auth_data.encrypted_access_token
assert tool_auth.encrypted_refresh_token == tool_auth_data.encrypted_refresh_token
assert tool_auth.expires_at == tool_auth_data.expires_at
assert tool_auth.id == tool_auth_data.id
assert tool_auth.created_at == tool_auth_data.created_at
assert tool_auth.updated_at == tool_auth_data.updated_at


def test_delete_tool_auth_by_tool_id(session, user):
tool_auth = get_factory("ToolAuth", session).create(
user_id=user.id,
tool_id=ToolName.Google_Drive,
token_type="Bearer",
encrypted_access_token=bytes(b"foobar"),
encrypted_refresh_token=bytes(b"foobar"),
expires_at=datetime.strptime("12/31/2124 00:00:00", "%m/%d/%Y %H:%M:%S"),
created_at=datetime.strptime("01/01/2000 00:00:00", "%m/%d/%Y %H:%M:%S"),
updated_at=datetime.strptime("01/01/2010 00:00:00", "%m/%d/%Y %H:%M:%S"),
)

tool_auth_crud.delete_tool_auth(session, user.id, tool_auth.tool_id)

tool_auth = tool_auth_crud.get_tool_auth(session, tool_auth.tool_id, user.id)
assert tool_auth is None
2 changes: 2 additions & 0 deletions src/backend/tests/unit/factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SnapshotFactory,
SnapshotLinkFactory,
)
from backend.tests.unit.factories.tool_auth import ToolAuthFactory
from backend.tests.unit.factories.tool_call import ToolCallFactory
from backend.tests.unit.factories.user import UserFactory

Expand All @@ -40,6 +41,7 @@
"Agent": AgentFactory,
"Organization": OrganizationFactory,
"ToolCall": ToolCallFactory,
"ToolAuth": ToolAuthFactory,
"Snapshot": SnapshotFactory,
"SnapshotLink": SnapshotLinkFactory,
"SnapshotAccess": SnapshotAccessFactory,
Expand Down
22 changes: 22 additions & 0 deletions src/backend/tests/unit/factories/tool_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from datetime import datetime

import factory

from backend.config.tools import ToolName
from backend.database_models.tool_auth import ToolAuth

from .base import BaseFactory


class ToolAuthFactory(BaseFactory):
class Meta:
model = ToolAuth

user_id = factory.Faker("uuid4")
tool_id = ToolName.Google_Drive
token_type = "Bearer"
encrypted_access_token = bytes(b"foobar")
encrypted_refresh_token = bytes(b"foobar")
expires_at = datetime.strptime("12/31/2124 00:00:00", "%m/%d/%Y %H:%M:%S")
created_at = datetime.strptime("01/01/2000 00:00:00", "%m/%d/%Y %H:%M:%S")
updated_at = datetime.strptime("01/01/2010 00:00:00", "%m/%d/%Y %H:%M:%S")
10 changes: 10 additions & 0 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def retrieve_auth_token(
def get_token(self, user_id: str, session: DBSessionDep) -> Optional[str]:
return None

def delete_tool_auth(self, session: DBSessionDep, user_id: str) -> bool:
try:
tool_auth_crud.delete_tool_auth(session, user_id, self.TOOL_ID)
return True
except Exception as e:
logger.error(
event=f"BaseToolAuthentication: Error while deleting Tool Auth: {str(e)}"
)
raise


class ToolAuthenticationCacheMixin:
def insert_tool_auth_cache(self, user_id: str, tool_id: str) -> str:
Expand Down

0 comments on commit b5f2991

Please sign in to comment.