diff --git a/.github/actions/check-app-accessibility/action.yml b/.github/actions/check-app-accessibility/action.yml index 76d1ce90f8..ea58602f83 100644 --- a/.github/actions/check-app-accessibility/action.yml +++ b/.github/actions/check-app-accessibility/action.yml @@ -4,6 +4,9 @@ inputs: log-file: description: 'The path to the log file containing the serve output' required: true + api-key: + description: 'The API key for accessing the app' + required: false runs: using: "composite" steps: @@ -28,8 +31,25 @@ runs: status_code=$(curl --max-time 60 --write-out %{http_code} --silent --output /dev/null --verbose $APP_URL) echo "Attempt $i: Status code: $status_code" if [ "$status_code" -eq 200 ]; then - echo "$APP_URL is accessible" + echo "$APP_URL is accessible without API key" exit 0 + elif [ "$status_code" -eq 401 ]; then + if [ -z "${{ inputs.api-key }}" ]; then + echo "Error: $APP_URL is not accessible without API key" + sleep 10 # Wait before retrying + fi + status_code=$( \ + curl --max-time 60 --write-out %{http_code} --silent --output /dev/null --verbose \ + -H "Authorization: Bearer ${{ inputs.api-key }}" \ + $APP_URL) + echo "Attempt $i: Status code: $status_code" + if [ "$status_code" -eq 200 ]; then + echo "$APP_URL is accessible with API key" + exit 0 + else + echo "Attempt $i failed: $APP_URL is not accessible" + sleep 10 # Wait before retrying + fi else echo "Attempt $i failed: $APP_URL is not accessible" sleep 10 # Wait before retrying diff --git a/.github/workflows/cli-commands-tests.yml b/.github/workflows/cli-commands-tests.yml index 5c43ef79ff..e6e539e800 100644 --- a/.github/workflows/cli-commands-tests.yml +++ b/.github/workflows/cli-commands-tests.yml @@ -104,6 +104,7 @@ jobs: uses: ./.github/actions/check-app-accessibility with: log-file: serve_output.log + api-key: ${{ secrets.AGENTA_API_KEY }} continue-on-error: false - name: Run agenta variant serve with overwrite diff --git a/agenta-backend/agenta_backend/main.py b/agenta-backend/agenta_backend/main.py index ae30c91122..87db329c97 100644 --- a/agenta-backend/agenta_backend/main.py +++ b/agenta-backend/agenta_backend/main.py @@ -14,6 +14,7 @@ bases_router, configs_router, health_router, + permissions_router, ) from agenta_backend.open_api import open_api_tags_metadata from agenta_backend.utils.common import isEE, isCloudProd, isCloudDev, isOss, isCloudEE @@ -96,6 +97,7 @@ async def lifespan(application: FastAPI, cache=True): app, allow_headers = cloud.extend_main(app) app.include_router(health_router.router, prefix="/health") +app.include_router(permissions_router.router, prefix="/permissions") app.include_router(user_profile.router, prefix="/profile") app.include_router(app_router.router, prefix="/apps", tags=["Apps"]) app.include_router(variants_router.router, prefix="/variants", tags=["Variants"]) diff --git a/agenta-backend/agenta_backend/routers/evaluation_router.py b/agenta-backend/agenta_backend/routers/evaluation_router.py index d3cd298181..454aedff37 100644 --- a/agenta-backend/agenta_backend/routers/evaluation_router.py +++ b/agenta-backend/agenta_backend/routers/evaluation_router.py @@ -130,7 +130,8 @@ async def create_evaluation( evaluate.delay( app_id=payload.app_id, - project_id=str(app.project_id), + user_id=str(request.state.user_id), + project_id=str(request.state.project_id), variant_id=variant_id, evaluators_config_ids=payload.evaluators_configs, testset_id=payload.testset_id, diff --git a/agenta-backend/agenta_backend/routers/permissions_router.py b/agenta-backend/agenta_backend/routers/permissions_router.py new file mode 100644 index 0000000000..01410f2b21 --- /dev/null +++ b/agenta-backend/agenta_backend/routers/permissions_router.py @@ -0,0 +1,100 @@ +from typing import Optional +from uuid import UUID + +from fastapi import Request, Query, HTTPException +from fastapi.responses import JSONResponse + +from agenta_backend.utils.common import isCloudEE, isOss, APIRouter +from agenta_backend.services import db_manager + +if isCloudEE(): + from agenta_backend.commons.models.shared_models import Permission + from agenta_backend.commons.utils.permissions import check_action_access + + +class Allow(JSONResponse): + def __init__( + self, + credentials: Optional[str] = None, + ) -> None: + super().__init__( + status_code=200, + content={ + "effect": "allow", + "credentials": credentials, + }, + ) + + +class Deny(HTTPException): + def __init__(self) -> None: + super().__init__( + status_code=401, + detail="Unauthorized", + ) + + +router = APIRouter() + + +@router.get( + "/verify", + operation_id="verify_permissions", +) +async def verify_permissions( + request: Request, + action: Optional[str] = Query(None), + resource_type: Optional[str] = Query(None), + resource_id: Optional[UUID] = Query(None), +): + try: + if isOss(): + return Allow(None) + + if not action or not resource_type or not resource_id: + raise Deny() + + if isCloudEE(): + permission = Permission(action) + + # CHECK PERMISSION 1/2: ACTION + allow_action = await check_action_access( + user_uid=request.state.user_id, + project_id=request.state.project_id, + permission=permission, + ) + + if not allow_action: + raise Deny() + + # CHECK PERMISSION 2/2: RESOURCE + allow_resource = await check_resource_access( + project_id=UUID(request.state.project_id), + resource_id=resource_id, + resource_type=resource_type, + ) + + if not allow_resource: + raise Deny() + + return Allow(request.state.credentials) + + except Exception as exc: # pylint: disable=bare-except + raise Deny() from exc + + +async def check_resource_access( + project_id: UUID, + resource_id: UUID, + resource_type: str, +) -> bool: + resource_project_id = None + + if resource_type == "application": + app = await db_manager.get_app_instance_by_id(app_id=str(resource_id)) + + resource_project_id = app.project_id + + allow_resource = resource_project_id == project_id + + return allow_resource diff --git a/agenta-backend/agenta_backend/services/llm_apps_service.py b/agenta-backend/agenta_backend/services/llm_apps_service.py index e6b1064fe3..341500acc3 100644 --- a/agenta-backend/agenta_backend/services/llm_apps_service.py +++ b/agenta-backend/agenta_backend/services/llm_apps_service.py @@ -9,6 +9,12 @@ from agenta_backend.models.shared_models import InvokationResult, Result, Error from agenta_backend.utils import common +from agenta_backend.utils.common import isCloudEE + +if isCloudEE(): + from agenta_backend.cloud.services.auth_helper import sign_secret_token + + # Set logger logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -85,7 +91,12 @@ async def make_payload( async def invoke_app( - uri: str, datapoint: Any, parameters: Dict, openapi_parameters: List[Dict] + uri: str, + datapoint: Any, + parameters: Dict, + openapi_parameters: List[Dict], + user_id: str, + project_id: str, ) -> InvokationResult: """ Invokes an app for one datapoint using the openapi_parameters to determine @@ -105,12 +116,25 @@ async def invoke_app( """ url = f"{uri}/generate" payload = await make_payload(datapoint, parameters, openapi_parameters) + + headers = None + + if isCloudEE(): + secret_token = await sign_secret_token(user_id, project_id, None) + + headers = {"Authorization": f"Secret {secret_token}"} + async with aiohttp.ClientSession() as client: app_response = {} try: logger.debug(f"Invoking app {uri} with payload {payload}") - response = await client.post(url, json=payload, timeout=900) + response = await client.post( + url, + json=payload, + headers=headers, + timeout=900, + ) app_response = await response.json() response.raise_for_status() @@ -174,6 +198,8 @@ async def run_with_retry( max_retry_count: int, retry_delay: int, openapi_parameters: List[Dict], + user_id: str, + project_id: str, ) -> InvokationResult: """ Runs the specified app with retry mechanism. @@ -195,7 +221,14 @@ async def run_with_retry( last_exception = None while retries < max_retry_count: try: - result = await invoke_app(uri, input_data, parameters, openapi_parameters) + result = await invoke_app( + uri, + input_data, + parameters, + openapi_parameters, + user_id, + project_id, + ) return result except aiohttp.ClientError as e: last_exception = e @@ -228,7 +261,12 @@ async def run_with_retry( async def batch_invoke( - uri: str, testset_data: List[Dict], parameters: Dict, rate_limit_config: Dict + uri: str, + testset_data: List[Dict], + parameters: Dict, + rate_limit_config: Dict, + user_id: str, + project_id: str, ) -> List[InvokationResult]: """ Invokes the LLm apps in batches, processing the testset data. @@ -273,6 +311,8 @@ async def run_batch(start_idx: int): max_retries, retry_delay, openapi_parameters, + user_id, + project_id, ) ) tasks.append(task) diff --git a/agenta-backend/agenta_backend/tasks/evaluations.py b/agenta-backend/agenta_backend/tasks/evaluations.py index c2388477e5..3b81fc712d 100644 --- a/agenta-backend/agenta_backend/tasks/evaluations.py +++ b/agenta-backend/agenta_backend/tasks/evaluations.py @@ -58,6 +58,7 @@ def evaluate( self, app_id: str, + user_id: str, project_id: str, variant_id: str, evaluators_config_ids: List[str], @@ -136,6 +137,8 @@ def evaluate( testset_db.csvdata, # type: ignore app_variant_parameters, # type: ignore rate_limit_config, + user_id, + project_id, ) ) diff --git a/agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py b/agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py index 3462f7c94c..54fefa8d09 100644 --- a/agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py +++ b/agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py @@ -34,7 +34,14 @@ async def test_batch_invoke_success(): ] # Mock the response of invoke_app to always succeed - def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): + def invoke_app_side_effect( + uri, + datapoint, + parameters, + openapi_parameters, + user_id, + project_id, + ): return InvokationResult( result=Result(type="text", value="Success", error=None), latency=0.1, @@ -56,7 +63,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): "delay_between_batches": 5, } - results = await batch_invoke(uri, testset_data, parameters, rate_limit_config) + results = await batch_invoke( + uri, + testset_data, + parameters, + rate_limit_config, + user_id="test_user", + project_id="test_project", + ) assert len(results) == 2 assert results[0].result.type == "text" @@ -89,7 +103,14 @@ async def test_batch_invoke_retries_and_failure(): ] # Mock the response of invoke_app to always fail - def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): + def invoke_app_side_effect( + uri, + datapoint, + parameters, + openapi_parameters, + user_id, + project_id, + ): raise aiohttp.ClientError("Test Error") mock_invoke_app.side_effect = invoke_app_side_effect @@ -107,7 +128,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): "delay_between_batches": 5, } - results = await batch_invoke(uri, testset_data, parameters, rate_limit_config) + results = await batch_invoke( + uri, + testset_data, + parameters, + rate_limit_config, + user_id="test_user", + project_id="test_project", + ) assert len(results) == 2 assert results[0].result.type == "error" @@ -140,7 +168,14 @@ async def test_batch_invoke_generic_exception(): ] # Mock the response of invoke_app to raise a generic exception - def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): + def invoke_app_side_effect( + uri, + datapoint, + parameters, + openapi_parameters, + user_id, + project_id, + ): raise Exception("Generic Error") mock_invoke_app.side_effect = invoke_app_side_effect @@ -155,7 +190,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters): "delay_between_batches": 1, } - results = await batch_invoke(uri, testset_data, parameters, rate_limit_config) + results = await batch_invoke( + uri, + testset_data, + parameters, + rate_limit_config, + user_id="test_user", + project_id="test_project", + ) assert len(results) == 1 assert results[0].result.type == "error" diff --git a/agenta-cli/agenta/sdk/decorators/routing.py b/agenta-cli/agenta/sdk/decorators/routing.py index 370a4f839c..4416cf8077 100644 --- a/agenta-cli/agenta/sdk/decorators/routing.py +++ b/agenta-cli/agenta/sdk/decorators/routing.py @@ -14,9 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi import Body, FastAPI, UploadFile, HTTPException +from agenta.sdk.middleware.auth import AuthorizationMiddleware from agenta.sdk.context.routing import routing_context_manager, routing_context from agenta.sdk.context.tracing import tracing_context -from agenta.sdk.router import router as router +from agenta.sdk.router import router from agenta.sdk.utils.exceptions import suppress from agenta.sdk.utils.logging import log from agenta.sdk.types import ( @@ -50,6 +51,9 @@ allow_headers=["*"], ) +_MIDDLEWARES = True + + app.include_router(router, prefix="") @@ -121,6 +125,26 @@ def __init__( route_path="", config_schema: Optional[BaseModel] = None, ): + ### --- Update Middleware --- # + try: + global _MIDDLEWARES # pylint: disable=global-statement + + if _MIDDLEWARES: + app.add_middleware( + AuthorizationMiddleware, + host=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.host, + resource_id=ag.DEFAULT_AGENTA_SINGLETON_INSTANCE.app_id, + resource_type="application", + ) + + _MIDDLEWARES = False + + except: # pylint: disable=bare-except + log.error("------------------------------------") + log.error("Agenta SDK - failed to secure route: %s", route_path) + log.error("------------------------------------") + ### --- Update Middleware --- # + DEFAULT_PATH = "generate" PLAYGROUND_PATH = "/playground" RUN_PATH = "/run" @@ -330,9 +354,9 @@ async def execute_function( *args, **func_params, ): - log.info(f"---------------------------") + log.info("---------------------------") log.info(f"Agenta SDK - running route: {repr(self.route_path or '/')}") - log.info(f"---------------------------") + log.info("---------------------------") tracing_context.set(routing_context.get()) diff --git a/agenta-cli/agenta/sdk/middleware/__init__.py b/agenta-cli/agenta/sdk/middleware/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py new file mode 100644 index 0000000000..7a017ef220 --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -0,0 +1,136 @@ +from typing import Callable, Optional +from os import environ +from uuid import UUID +from json import dumps +from traceback import format_exc + +import httpx +from starlette.middleware.base import BaseHTTPMiddleware +from fastapi import FastAPI, Request, Response + +from agenta.sdk.utils.logging import log +from agenta.sdk.middleware.cache import TTLLRUCache + +AGENTA_SDK_AUTH_CACHE_CAPACITY = environ.get( + "AGENTA_SDK_AUTH_CACHE_CAPACITY", + 512, +) + +AGENTA_SDK_AUTH_CACHE_TTL = environ.get( + "AGENTA_SDK_AUTH_CACHE_TTL", + 15 * 60, # 15 minutes +) + +AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( + environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) +).lower() in ("true", "1", "t") + + +class Deny(Response): + def __init__(self) -> None: + super().__init__(status_code=401, content="Unauthorized") + + +cache = TTLLRUCache( + capacity=AGENTA_SDK_AUTH_CACHE_CAPACITY, + ttl=AGENTA_SDK_AUTH_CACHE_TTL, +) + + +class AuthorizationMiddleware(BaseHTTPMiddleware): + def __init__( + self, + app: FastAPI, + host: str, + resource_id: UUID, + resource_type: str, + ): + super().__init__(app) + + self.host = host + self.resource_id = resource_id + self.resource_type = resource_type + + async def dispatch( + self, + request: Request, + call_next: Callable, + project_id: Optional[UUID] = None, + ): + if AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED: + return await call_next(request) + + try: + authorization = ( + request.headers.get("Authorization") + or request.headers.get("authorization") + or None + ) + + headers = {"Authorization": authorization} if authorization else None + + cookies = {"sAccessToken": request.cookies.get("sAccessToken")} + + params = { + "action": "run_service", + "resource_type": self.resource_type, + "resource_id": self.resource_id, + } + + if project_id: + params["project_id"] = project_id + + _hash = dumps( + { + "headers": headers, + "cookies": cookies, + "params": params, + }, + sort_keys=True, + ) + + cached_policy = cache.get(_hash) + + if not cached_policy: + async with httpx.AsyncClient() as client: + response = await client.get( + f"{self.host}/api/permissions/verify", + headers=headers, + cookies=cookies, + params=params, + ) + + if response.status_code != 200: + cache.put(_hash, {"effect": "deny"}) + return Deny() + + auth = response.json() + + if auth.get("effect") != "allow": + cache.put(_hash, {"effect": "deny"}) + return Deny() + + cached_policy = { + "effect": "allow", + "credentials": auth.get("credentials"), + } + + cache.put(_hash, cached_policy) + + if cached_policy.get("effect") == "deny": + return Deny() + + request.state.credentials = cached_policy.get("credentials") + + print(f"credentials: {request.state.credentials}") + + return await call_next(request) + + except: # pylint: disable=bare-except + log.error("------------------------------------------------------") + log.error("Agenta SDK - handling auth middleware exception below:") + log.error("------------------------------------------------------") + log.error(format_exc().strip("\n")) + log.error("------------------------------------------------------") + + return Deny() diff --git a/agenta-cli/agenta/sdk/middleware/cache.py b/agenta-cli/agenta/sdk/middleware/cache.py new file mode 100644 index 0000000000..5445b1fafc --- /dev/null +++ b/agenta-cli/agenta/sdk/middleware/cache.py @@ -0,0 +1,43 @@ +from time import time +from collections import OrderedDict + + +class TTLLRUCache: + def __init__(self, capacity: int, ttl: int): + self.cache = OrderedDict() + self.capacity = capacity + self.ttl = ttl + + def get(self, key): + # CACHE + if key not in self.cache: + return None + + value, expiry = self.cache[key] + # ----- + + # TTL + if time() > expiry: + del self.cache[key] + + return None + # --- + + # LRU + self.cache.move_to_end(key) + # --- + + return value + + def put(self, key, value): + # CACHE + if key in self.cache: + del self.cache[key] + # CACHE & LRU + elif len(self.cache) >= self.capacity: + self.cache.popitem(last=False) + # ----------- + + # TTL + self.cache[key] = (value, time() + self.ttl) + # --- diff --git a/agenta-cli/agenta/sdk/tracing/exporters.py b/agenta-cli/agenta/sdk/tracing/exporters.py index 11c96df815..62f03a10b5 100644 --- a/agenta-cli/agenta/sdk/tracing/exporters.py +++ b/agenta-cli/agenta/sdk/tracing/exporters.py @@ -58,7 +58,7 @@ def fetch( return trace -OTLPSpanExporter._MAX_RETRY_TIMEOUT = 2 +OTLPSpanExporter._MAX_RETRY_TIMEOUT = 2 # pylint: disable=protected-access ConsoleExporter = ConsoleSpanExporter InlineExporter = InlineTraceExporter diff --git a/agenta-cli/agenta/sdk/tracing/processors.py b/agenta-cli/agenta/sdk/tracing/processors.py index 2d44b54179..8cd14d76d6 100644 --- a/agenta-cli/agenta/sdk/tracing/processors.py +++ b/agenta-cli/agenta/sdk/tracing/processors.py @@ -12,7 +12,7 @@ from agenta.sdk.utils.logging import log -# LOAD CONTEXT, HERE +# LOAD CONTEXT, HERE ! class TraceProcessor(BatchSpanProcessor): diff --git a/agenta-cli/pyproject.toml b/agenta-cli/pyproject.toml index af16d0ff9e..9f21b6db37 100644 --- a/agenta-cli/pyproject.toml +++ b/agenta-cli/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "agenta" -version = "0.27.3" +version = "0.27.4a0" description = "The SDK for agenta is an open-source LLMOps platform." readme = "README.md" authors = ["Mahmoud Mabrouk "]