diff --git a/agenta-backend/agenta_backend/routers/permissions_router.py b/agenta-backend/agenta_backend/routers/permissions_router.py index 01410f2b21..9e7532b975 100644 --- a/agenta-backend/agenta_backend/routers/permissions_router.py +++ b/agenta-backend/agenta_backend/routers/permissions_router.py @@ -29,8 +29,8 @@ def __init__( class Deny(HTTPException): def __init__(self) -> None: super().__init__( - status_code=401, - detail="Unauthorized", + status_code=403, + detail="Forbidden", ) diff --git a/agenta-backend/agenta_backend/services/llm_apps_service.py b/agenta-backend/agenta_backend/services/llm_apps_service.py index 341500acc3..51d708d05f 100644 --- a/agenta-backend/agenta_backend/services/llm_apps_service.py +++ b/agenta-backend/agenta_backend/services/llm_apps_service.py @@ -3,7 +3,7 @@ import asyncio import traceback import aiohttp -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from agenta_backend.models.shared_models import InvokationResult, Result, Error @@ -296,7 +296,17 @@ async def batch_invoke( list_of_app_outputs: List[ InvokationResult ] = [] # Outputs after running all batches - openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json") + + headers = None + if isCloudEE(): + secret_token = await sign_secret_token(user_id, project_id, None) + + headers = {"Authorization": f"Secret {secret_token}"} + + openapi_parameters = await get_parameters_from_openapi( + uri + "/openapi.json", + headers, + ) async def run_batch(start_idx: int): tasks = [] @@ -336,7 +346,10 @@ async def run_batch(start_idx: int): return list_of_app_outputs -async def get_parameters_from_openapi(uri: str) -> List[Dict]: +async def get_parameters_from_openapi( + uri: str, + headers: Optional[Dict[str, str]], +) -> List[Dict]: """ Parse the OpenAI schema of an LLM app to return list of parameters that it takes with their type as determined by the x-parameter Args: @@ -351,7 +364,7 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]: """ - schema = await _get_openai_json_from_uri(uri) + schema = await _get_openai_json_from_uri(uri, headers) try: body_schema_name = ( @@ -381,9 +394,12 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]: return parameters -async def _get_openai_json_from_uri(uri): +async def _get_openai_json_from_uri( + uri: str, + headers: Optional[Dict[str, str]], +): async with aiohttp.ClientSession() as client: - resp = await client.get(uri, timeout=5) + resp = await client.get(uri, headers=headers, timeout=5) resp_text = await resp.text() json_data = json.loads(resp_text) return json_data diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py index 3b81fc712d..d2fe0d014c 100644 --- a/agenta-backend/agenta_backend/tasks/evaluations.py +++ b/agenta-backend/agenta_backend/tasks/evaluations.py @@ -6,6 +6,10 @@ from celery import shared_task, states from agenta_backend.utils.common import isCloudEE + +if isCloudEE(): + from agenta_backend.cloud.services.auth_helper import sign_secret_token + from agenta_backend.services import ( evaluators_service, llm_apps_service, @@ -143,8 +147,20 @@ def evaluate( ) # 4. Evaluate the app outputs + secret_token = None + headers = None + if isCloudEE(): + secret_token = loop.run_until_complete( + sign_secret_token(user_id, project_id, None) + ) + if secret_token: + headers = {"Authorization": f"Secret {secret_token}"} + openapi_parameters = loop.run_until_complete( - llm_apps_service.get_parameters_from_openapi(uri + "/openapi.json") + llm_apps_service.get_parameters_from_openapi( + uri + "/openapi.json", + headers, + ), ) for data_point, app_output in zip(testset_db.csvdata, app_outputs): # type: ignore diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 4416cf8077..4e9f7ba017 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -37,27 +37,21 @@ import agenta as ag -app = FastAPI() - -origins = [ - "*", -] -app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], +AGENTA_USE_CORS = str(environ.get("AGENTA_USE_CORS", "true")).lower() in ( + "true", + "1", + "t", ) -_MIDDLEWARES = True +app = FastAPI() +log.setLevel("DEBUG") -app.include_router(router, prefix="") +_MIDDLEWARES = True -log.setLevel("DEBUG") +app.include_router(router, prefix="") class PathValidator(BaseModel): @@ -137,6 +131,15 @@ def __init__( resource_type="application", ) + if AGENTA_USE_CORS: + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, + ) + _MIDDLEWARES = False except: # pylint: disable=bare-except diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 7a017ef220..6663a02dba 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -21,6 +21,12 @@ 15 * 60, # 15 minutes ) +AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in ( + "true", + "1", + "t", +) + AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) ).lower() in ("true", "1", "t") @@ -89,9 +95,11 @@ async def dispatch( sort_keys=True, ) - cached_policy = cache.get(_hash) + policy = None + if AGENTA_SDK_AUTH_CACHE: + policy = cache.get(_hash) - if not cached_policy: + if not policy: async with httpx.AsyncClient() as client: response = await client.get( f"{self.host}/api/permissions/verify", @@ -110,19 +118,17 @@ async def dispatch( cache.put(_hash, {"effect": "deny"}) return Deny() - cached_policy = { + policy = { "effect": "allow", "credentials": auth.get("credentials"), } - cache.put(_hash, cached_policy) + cache.put(_hash, policy) - if cached_policy.get("effect") == "deny": + if not policy or policy.get("effect") == "deny": return Deny() - request.state.credentials = cached_policy.get("credentials") - - print(f"credentials: {request.state.credentials}") + request.state.credentials = policy.get("credentials") return await call_next(request) diff --git a/agenta-cli/agenta/sdk/types.py b/agenta-cli/agenta/sdk/types.py index cab9fb4b2c..3852cee82d 100644 --- a/agenta-cli/agenta/sdk/types.py +++ b/agenta-cli/agenta/sdk/types.py @@ -24,8 +24,8 @@ class LLMTokenUsage(BaseModel): class BaseResponse(BaseModel): version: Optional[str] = "2.0" - data: Optional[Union[str, Dict[str, Any]]] - trace: Optional[Dict[str, Any]] + data: Optional[Union[str, Dict[str, Any]]] = None + trace: Optional[Dict[str, Any]] = None class DictInput(dict): diff --git a/agenta-web/src/services/api.ts b/agenta-web/src/services/api.ts index 93cf6e7fc9..0f26e1e144 100644 --- a/agenta-web/src/services/api.ts +++ b/agenta-web/src/services/api.ts @@ -1,4 +1,5 @@ import axios from "@/lib//helpers/axiosConfig" +import Session from "supertokens-auth-react/recipe/session" import {formatDay} from "@/lib/helpers/dateTimeHelper" import { detectChatVariantFromOpenAISchema, @@ -113,17 +114,36 @@ export async function callVariant( } const appContainerURI = await fetchAppContainerURL(appId, undefined, baseId) + const jwt = await getJWT() return axios .post(`${appContainerURI}/generate`, requestBody, { signal, _ignoreError: ignoreAxiosError, + headers: { + Authorization: jwt && `Bearer ${jwt}`, + }, } as any) .then((res) => { return res.data }) } +/** + * Get the JWT from SuperTokens + */ +const getJWT = async () => { + try { + if (await Session.doesSessionExist()) { + let jwt = await Session.getAccessToken() + + return jwt + } + } catch (error) {} + + return undefined +} + /** * Parses the openapi.json from a variant and returns the parameters as an array of objects. * @param app @@ -138,7 +158,13 @@ export const fetchVariantParametersFromOpenAPI = async ( ) => { const appContainerURI = await fetchAppContainerURL(appId, variantId, baseId) const url = `${appContainerURI}/openapi.json` - const response = await axios.get(url, {_ignoreError: ignoreAxiosError} as any) + const jwt = await getJWT() + const response = await axios.get(url, { + _ignoreError: ignoreAxiosError, + headers: { + Authorization: jwt && `Bearer ${jwt}`, + }, + } as any) const isChatVariant = detectChatVariantFromOpenAISchema(response.data) let APIParams = openAISchemaToParameters(response.data)