Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]: CORS + App Security hotfix #2283

Merged
merged 20 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions agenta-backend/agenta_backend/routers/permissions_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(
class Deny(HTTPException):
def __init__(self) -> None:
super().__init__(
status_code=401,
detail="Unauthorized",
status_code=403,
detail="Forbidden",
)


Expand Down
28 changes: 22 additions & 6 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import traceback
import aiohttp
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional


from agenta_backend.models.shared_models import InvokationResult, Result, Error
Expand Down Expand Up @@ -296,7 +296,17 @@ async def batch_invoke(
list_of_app_outputs: List[
InvokationResult
] = [] # Outputs after running all batches
openapi_parameters = await get_parameters_from_openapi(uri + "/openapi.json")

headers = None
if isCloudEE():
secret_token = await sign_secret_token(user_id, project_id, None)

headers = {"Authorization": f"Secret {secret_token}"}

openapi_parameters = await get_parameters_from_openapi(
uri + "/openapi.json",
headers,
)

async def run_batch(start_idx: int):
tasks = []
Expand Down Expand Up @@ -336,7 +346,10 @@ async def run_batch(start_idx: int):
return list_of_app_outputs


async def get_parameters_from_openapi(uri: str) -> List[Dict]:
async def get_parameters_from_openapi(
uri: str,
headers: Optional[Dict[str, str]],
) -> List[Dict]:
"""
Parse the OpenAI schema of an LLM app to return list of parameters that it takes with their type as determined by the x-parameter
Args:
Expand All @@ -351,7 +364,7 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:

"""

schema = await _get_openai_json_from_uri(uri)
schema = await _get_openai_json_from_uri(uri, headers)

try:
body_schema_name = (
Expand Down Expand Up @@ -381,9 +394,12 @@ async def get_parameters_from_openapi(uri: str) -> List[Dict]:
return parameters


async def _get_openai_json_from_uri(uri):
async def _get_openai_json_from_uri(
uri: str,
headers: Optional[Dict[str, str]],
):
async with aiohttp.ClientSession() as client:
resp = await client.get(uri, timeout=5)
resp = await client.get(uri, headers=headers, timeout=5)
resp_text = await resp.text()
json_data = json.loads(resp_text)
return json_data
18 changes: 17 additions & 1 deletion agenta-backend/agenta_backend/tasks/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from celery import shared_task, states

from agenta_backend.utils.common import isCloudEE

if isCloudEE():
from agenta_backend.cloud.services.auth_helper import sign_secret_token

from agenta_backend.services import (
evaluators_service,
llm_apps_service,
Expand Down Expand Up @@ -143,8 +147,20 @@ def evaluate(
)

# 4. Evaluate the app outputs
secret_token = None
headers = None
if isCloudEE():
secret_token = loop.run_until_complete(
sign_secret_token(user_id, project_id, None)
)
if secret_token:
headers = {"Authorization": f"Secret {secret_token}"}

openapi_parameters = loop.run_until_complete(
llm_apps_service.get_parameters_from_openapi(uri + "/openapi.json")
llm_apps_service.get_parameters_from_openapi(
uri + "/openapi.json",
headers,
),
)

for data_point, app_output in zip(testset_db.csvdata, app_outputs): # type: ignore
Expand Down
31 changes: 17 additions & 14 deletions agenta-cli/agenta/sdk/decorators/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,21 @@

import agenta as ag

app = FastAPI()

origins = [
"*",
]

app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
AGENTA_USE_CORS = str(environ.get("AGENTA_USE_CORS", "true")).lower() in (
"true",
"1",
"t",
)

_MIDDLEWARES = True
app = FastAPI()
log.setLevel("DEBUG")


app.include_router(router, prefix="")
_MIDDLEWARES = True


log.setLevel("DEBUG")
app.include_router(router, prefix="")


class PathValidator(BaseModel):
Expand Down Expand Up @@ -137,6 +131,15 @@ def __init__(
resource_type="application",
)

if AGENTA_USE_CORS:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
allow_credentials=True,
)

_MIDDLEWARES = False

except: # pylint: disable=bare-except
Expand Down
22 changes: 14 additions & 8 deletions agenta-cli/agenta/sdk/middleware/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
15 * 60, # 15 minutes
)

AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in (
"true",
"1",
"t",
)

AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str(
environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False)
).lower() in ("true", "1", "t")
Expand Down Expand Up @@ -89,9 +95,11 @@ async def dispatch(
sort_keys=True,
)

cached_policy = cache.get(_hash)
policy = None
if AGENTA_SDK_AUTH_CACHE:
policy = cache.get(_hash)

if not cached_policy:
if not policy:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.host}/api/permissions/verify",
Expand All @@ -110,19 +118,17 @@ async def dispatch(
cache.put(_hash, {"effect": "deny"})
return Deny()

cached_policy = {
policy = {
"effect": "allow",
"credentials": auth.get("credentials"),
}

cache.put(_hash, cached_policy)
cache.put(_hash, policy)

if cached_policy.get("effect") == "deny":
if policy.get("effect") == "deny":
return Deny()

request.state.credentials = cached_policy.get("credentials")

print(f"credentials: {request.state.credentials}")
request.state.credentials = policy.get("credentials")

return await call_next(request)

Expand Down
4 changes: 2 additions & 2 deletions agenta-cli/agenta/sdk/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ class LLMTokenUsage(BaseModel):

class BaseResponse(BaseModel):
version: Optional[str] = "2.0"
data: Optional[Union[str, Dict[str, Any]]]
trace: Optional[Dict[str, Any]]
data: Optional[Union[str, Dict[str, Any]]] = None
trace: Optional[Dict[str, Any]] = None


class DictInput(dict):
Expand Down
26 changes: 25 additions & 1 deletion agenta-web/src/services/api.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import axios from "@/lib//helpers/axiosConfig"
import Session from "supertokens-auth-react/recipe/session"
import {formatDay} from "@/lib/helpers/dateTimeHelper"
import {
detectChatVariantFromOpenAISchema,
Expand Down Expand Up @@ -113,17 +114,34 @@ export async function callVariant(
}

const appContainerURI = await fetchAppContainerURL(appId, undefined, baseId)
const jwt = await getJWT()

return axios
.post(`${appContainerURI}/generate`, requestBody, {
signal,
_ignoreError: ignoreAxiosError,
headers: {
Authorization: `Bearer ${jwt}`,
},
} as any)
.then((res) => {
return res.data
})
}

/**
* Get the JWT from SuperTokens
*/
const getJWT = async () => {
if (await Session.doesSessionExist()) {
let jwt = await Session.getAccessToken()

return jwt
}

return undefined
}

/**
* Parses the openapi.json from a variant and returns the parameters as an array of objects.
* @param app
Expand All @@ -138,7 +156,13 @@ export const fetchVariantParametersFromOpenAPI = async (
) => {
const appContainerURI = await fetchAppContainerURL(appId, variantId, baseId)
const url = `${appContainerURI}/openapi.json`
const response = await axios.get(url, {_ignoreError: ignoreAxiosError} as any)
const jwt = await getJWT()
const response = await axios.get(url, {
_ignoreError: ignoreAxiosError,
headers: {
Authorization: `Bearer ${jwt}`,
},
} as any)
const isChatVariant = detectChatVariantFromOpenAISchema(response.data)
let APIParams = openAISchemaToParameters(response.data)

Expand Down
Loading