Skip to content

Commit

Permalink
Merge pull request #2166 from Agenta-AI/feature/app-security
Browse files Browse the repository at this point in the history
[Feature] Application Security
  • Loading branch information
jp-agenta authored Nov 18, 2024
2 parents 17e1f4d + 3f51172 commit 5b6086f
Show file tree
Hide file tree
Showing 15 changed files with 430 additions and 18 deletions.
22 changes: 21 additions & 1 deletion .github/actions/check-app-accessibility/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ inputs:
log-file:
description: 'The path to the log file containing the serve output'
required: true
api-key:
description: 'The API key for accessing the app'
required: false
runs:
using: "composite"
steps:
Expand All @@ -28,8 +31,25 @@ runs:
status_code=$(curl --max-time 60 --write-out %{http_code} --silent --output /dev/null --verbose $APP_URL)
echo "Attempt $i: Status code: $status_code"
if [ "$status_code" -eq 200 ]; then
echo "$APP_URL is accessible"
echo "$APP_URL is accessible without API key"
exit 0
elif [ "$status_code" -eq 401 ]; then
if [ -z "${{ inputs.api-key }}" ]; then
echo "Error: $APP_URL is not accessible without API key"
sleep 10 # Wait before retrying
fi
status_code=$( \
curl --max-time 60 --write-out %{http_code} --silent --output /dev/null --verbose \
-H "Authorization: Bearer ${{ inputs.api-key }}" \
$APP_URL)
echo "Attempt $i: Status code: $status_code"
if [ "$status_code" -eq 200 ]; then
echo "$APP_URL is accessible with API key"
exit 0
else
echo "Attempt $i failed: $APP_URL is not accessible"
sleep 10 # Wait before retrying
fi
else
echo "Attempt $i failed: $APP_URL is not accessible"
sleep 10 # Wait before retrying
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/cli-commands-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ jobs:
uses: ./.github/actions/check-app-accessibility
with:
log-file: serve_output.log
api-key: ${{ secrets.AGENTA_API_KEY }}
continue-on-error: false

- name: Run agenta variant serve with overwrite
Expand Down
2 changes: 2 additions & 0 deletions agenta-backend/agenta_backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
bases_router,
configs_router,
health_router,
permissions_router,
)
from agenta_backend.open_api import open_api_tags_metadata
from agenta_backend.utils.common import isEE, isCloudProd, isCloudDev, isOss, isCloudEE
Expand Down Expand Up @@ -96,6 +97,7 @@ async def lifespan(application: FastAPI, cache=True):
app, allow_headers = cloud.extend_main(app)

app.include_router(health_router.router, prefix="/health")
app.include_router(permissions_router.router, prefix="/permissions")
app.include_router(user_profile.router, prefix="/profile")
app.include_router(app_router.router, prefix="/apps", tags=["Apps"])
app.include_router(variants_router.router, prefix="/variants", tags=["Variants"])
Expand Down
3 changes: 2 additions & 1 deletion agenta-backend/agenta_backend/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ async def create_evaluation(

evaluate.delay(
app_id=payload.app_id,
project_id=str(app.project_id),
user_id=str(request.state.user_id),
project_id=str(request.state.project_id),
variant_id=variant_id,
evaluators_config_ids=payload.evaluators_configs,
testset_id=payload.testset_id,
Expand Down
100 changes: 100 additions & 0 deletions agenta-backend/agenta_backend/routers/permissions_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Optional
from uuid import UUID

from fastapi import Request, Query, HTTPException
from fastapi.responses import JSONResponse

from agenta_backend.utils.common import isCloudEE, isOss, APIRouter
from agenta_backend.services import db_manager

if isCloudEE():
from agenta_backend.commons.models.shared_models import Permission
from agenta_backend.commons.utils.permissions import check_action_access


class Allow(JSONResponse):
def __init__(
self,
credentials: Optional[str] = None,
) -> None:
super().__init__(
status_code=200,
content={
"effect": "allow",
"credentials": credentials,
},
)


class Deny(HTTPException):
def __init__(self) -> None:
super().__init__(
status_code=401,
detail="Unauthorized",
)


router = APIRouter()


@router.get(
"/verify",
operation_id="verify_permissions",
)
async def verify_permissions(
request: Request,
action: Optional[str] = Query(None),
resource_type: Optional[str] = Query(None),
resource_id: Optional[UUID] = Query(None),
):
try:
if isOss():
return Allow(None)

if not action or not resource_type or not resource_id:
raise Deny()

if isCloudEE():
permission = Permission(action)

# CHECK PERMISSION 1/2: ACTION
allow_action = await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=permission,
)

if not allow_action:
raise Deny()

# CHECK PERMISSION 2/2: RESOURCE
allow_resource = await check_resource_access(
project_id=UUID(request.state.project_id),
resource_id=resource_id,
resource_type=resource_type,
)

if not allow_resource:
raise Deny()

return Allow(request.state.credentials)

except Exception as exc: # pylint: disable=bare-except
raise Deny() from exc


async def check_resource_access(
project_id: UUID,
resource_id: UUID,
resource_type: str,
) -> bool:
resource_project_id = None

if resource_type == "application":
app = await db_manager.get_app_instance_by_id(app_id=str(resource_id))

resource_project_id = app.project_id

allow_resource = resource_project_id == project_id

return allow_resource
48 changes: 44 additions & 4 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from agenta_backend.models.shared_models import InvokationResult, Result, Error
from agenta_backend.utils import common

from agenta_backend.utils.common import isCloudEE

if isCloudEE():
from agenta_backend.cloud.services.auth_helper import sign_secret_token


# Set logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -85,7 +91,12 @@ async def make_payload(


async def invoke_app(
uri: str, datapoint: Any, parameters: Dict, openapi_parameters: List[Dict]
uri: str,
datapoint: Any,
parameters: Dict,
openapi_parameters: List[Dict],
user_id: str,
project_id: str,
) -> InvokationResult:
"""
Invokes an app for one datapoint using the openapi_parameters to determine
Expand All @@ -105,12 +116,25 @@ async def invoke_app(
"""
url = f"{uri}/generate"
payload = await make_payload(datapoint, parameters, openapi_parameters)

headers = None

if isCloudEE():
secret_token = await sign_secret_token(user_id, project_id, None)

headers = {"Authorization": f"Secret {secret_token}"}

async with aiohttp.ClientSession() as client:
app_response = {}

try:
logger.debug(f"Invoking app {uri} with payload {payload}")
response = await client.post(url, json=payload, timeout=900)
response = await client.post(
url,
json=payload,
headers=headers,
timeout=900,
)
app_response = await response.json()
response.raise_for_status()

Expand Down Expand Up @@ -174,6 +198,8 @@ async def run_with_retry(
max_retry_count: int,
retry_delay: int,
openapi_parameters: List[Dict],
user_id: str,
project_id: str,
) -> InvokationResult:
"""
Runs the specified app with retry mechanism.
Expand All @@ -195,7 +221,14 @@ async def run_with_retry(
last_exception = None
while retries < max_retry_count:
try:
result = await invoke_app(uri, input_data, parameters, openapi_parameters)
result = await invoke_app(
uri,
input_data,
parameters,
openapi_parameters,
user_id,
project_id,
)
return result
except aiohttp.ClientError as e:
last_exception = e
Expand Down Expand Up @@ -228,7 +261,12 @@ async def run_with_retry(


async def batch_invoke(
uri: str, testset_data: List[Dict], parameters: Dict, rate_limit_config: Dict
uri: str,
testset_data: List[Dict],
parameters: Dict,
rate_limit_config: Dict,
user_id: str,
project_id: str,
) -> List[InvokationResult]:
"""
Invokes the LLm apps in batches, processing the testset data.
Expand Down Expand Up @@ -273,6 +311,8 @@ async def run_batch(start_idx: int):
max_retries,
retry_delay,
openapi_parameters,
user_id,
project_id,
)
)
tasks.append(task)
Expand Down
3 changes: 3 additions & 0 deletions agenta-backend/agenta_backend/tasks/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
def evaluate(
self,
app_id: str,
user_id: str,
project_id: str,
variant_id: str,
evaluators_config_ids: List[str],
Expand Down Expand Up @@ -136,6 +137,8 @@ def evaluate(
testset_db.csvdata, # type: ignore
app_variant_parameters, # type: ignore
rate_limit_config,
user_id,
project_id,
)
)

Expand Down
54 changes: 48 additions & 6 deletions agenta-backend/agenta_backend/tests/unit/test_llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,14 @@ async def test_batch_invoke_success():
]

# Mock the response of invoke_app to always succeed
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
def invoke_app_side_effect(
uri,
datapoint,
parameters,
openapi_parameters,
user_id,
project_id,
):
return InvokationResult(
result=Result(type="text", value="Success", error=None),
latency=0.1,
Expand All @@ -56,7 +63,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
"delay_between_batches": 5,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)
results = await batch_invoke(
uri,
testset_data,
parameters,
rate_limit_config,
user_id="test_user",
project_id="test_project",
)

assert len(results) == 2
assert results[0].result.type == "text"
Expand Down Expand Up @@ -89,7 +103,14 @@ async def test_batch_invoke_retries_and_failure():
]

# Mock the response of invoke_app to always fail
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
def invoke_app_side_effect(
uri,
datapoint,
parameters,
openapi_parameters,
user_id,
project_id,
):
raise aiohttp.ClientError("Test Error")

mock_invoke_app.side_effect = invoke_app_side_effect
Expand All @@ -107,7 +128,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
"delay_between_batches": 5,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)
results = await batch_invoke(
uri,
testset_data,
parameters,
rate_limit_config,
user_id="test_user",
project_id="test_project",
)

assert len(results) == 2
assert results[0].result.type == "error"
Expand Down Expand Up @@ -140,7 +168,14 @@ async def test_batch_invoke_generic_exception():
]

# Mock the response of invoke_app to raise a generic exception
def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
def invoke_app_side_effect(
uri,
datapoint,
parameters,
openapi_parameters,
user_id,
project_id,
):
raise Exception("Generic Error")

mock_invoke_app.side_effect = invoke_app_side_effect
Expand All @@ -155,7 +190,14 @@ def invoke_app_side_effect(uri, datapoint, parameters, openapi_parameters):
"delay_between_batches": 1,
}

results = await batch_invoke(uri, testset_data, parameters, rate_limit_config)
results = await batch_invoke(
uri,
testset_data,
parameters,
rate_limit_config,
user_id="test_user",
project_id="test_project",
)

assert len(results) == 1
assert results[0].result.type == "error"
Expand Down
Loading

0 comments on commit 5b6086f

Please sign in to comment.