Skip to content

Commit

Permalink
Merge pull request #2266 from Agenta-AI/fix/test-sets-without-app-id
Browse files Browse the repository at this point in the history
[feat] global testsets
  • Loading branch information
jp-agenta authored Nov 22, 2024
2 parents ad20ab9 + 3169d8c commit 7fcc849
Show file tree
Hide file tree
Showing 47 changed files with 6,927 additions and 407 deletions.
8 changes: 4 additions & 4 deletions agenta-backend/agenta_backend/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
response_model=List[str],
)
async def fetch_evaluation_ids(
app_id: str,
resource_type: str,
request: Request,
resource_ids: List[str] = Query(None),
Expand All @@ -52,11 +51,10 @@ async def fetch_evaluation_ids(
List[str]: A list of evaluation ids.
"""
try:
app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
permission=Permission.VIEW_EVALUATION,
)
logger.debug(
Expand All @@ -70,7 +68,9 @@ async def fetch_evaluation_ids(
status_code=403,
)
evaluations = await db_manager.fetch_evaluations_by_resource(
resource_type, str(app.project_id), resource_ids
resource_type,
request.state.project_id,
resource_ids,
)
return list(map(lambda x: str(x.id), evaluations))
except Exception as exc:
Expand Down
96 changes: 56 additions & 40 deletions agenta-backend/agenta_backend/routers/testset_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ async def upload_file(
upload_type: str = Form(None),
file: UploadFile = File(...),
testset_name: Optional[str] = File(None),
app_id: str = Form(None),
):
"""
Uploads a CSV or JSON file and saves its data to MongoDB.
Expand All @@ -67,11 +66,10 @@ async def upload_file(
dict: The result of the upload process.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to upload Testset: {has_permission}")
Expand Down Expand Up @@ -114,7 +112,8 @@ async def upload_file(

try:
testset = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=document
project_id=request.state.project_id,
testset_data=document,
)
return TestSetSimpleResponse(
id=str(testset.id),
Expand All @@ -132,7 +131,6 @@ async def import_testset(
request: Request,
endpoint: str = Form(None),
testset_name: str = Form(None),
app_id: str = Form(None),
):
"""
Import JSON testset data from an endpoint and save it to MongoDB.
Expand All @@ -145,11 +143,10 @@ async def import_testset(
dict: The result of the import process.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to import Testset: {has_permission}")
Expand Down Expand Up @@ -180,7 +177,8 @@ async def import_testset(
document["csvdata"].append(row)

testset = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=document
project_id=request.state.project_id,
testset_data=document,
)
return TestSetSimpleResponse(
id=str(testset.id),
Expand All @@ -204,30 +202,30 @@ async def import_testset(


@router.post(
"/{app_id}/", response_model=TestSetSimpleResponse, operation_id="create_testset"
"/{app_id}",
response_model=TestSetSimpleResponse,
operation_id="deprecating_create_testset",
)
@router.post("/", response_model=TestSetSimpleResponse, operation_id="create_testset")
async def create_testset(
app_id: str,
csvdata: NewTestset,
request: Request,
):
"""
Create a testset with given name and app_name, save the testset to MongoDB.
Create a testset with given name, save the testset to MongoDB.
Args:
name (str): name of the test set.
app_name (str): name of the application.
testset (Dict[str, str]): test set data.
Returns:
str: The id of the test set created.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to create Testset: {has_permission}")
Expand All @@ -245,7 +243,8 @@ async def create_testset(
"csvdata": csvdata.csvdata,
}
testset_instance = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=testset_data
project_id=request.state.project_id,
testset_data=testset_data,
)
if testset_instance is not None:
return TestSetSimpleResponse(
Expand Down Expand Up @@ -315,7 +314,6 @@ async def update_testset(

@router.get("/", operation_id="get_testsets")
async def get_testsets(
app_id: str,
request: Request,
) -> List[TestSetOutputResponse]:
"""
Expand All @@ -328,34 +326,52 @@ async def get_testsets(
- `HTTPException` with status code 404 if no testsets are found.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
permission=Permission.VIEW_TESTSET,
try:
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_TESTSET,
)

logger.debug(
"User has Permission to view Testsets: %s",
has_permission,
)

if not has_permission:
error_msg = (
"You do not have permission to perform this action. "
+ "Please contact your organization admin."
)
logger.error(error_msg)

return JSONResponse(
status_code=403,
content={"detail": error_msg},
)

testsets = await db_manager.fetch_testsets_by_project_id(
project_id=request.state.project_id,
)
logger.debug(f"User has Permission to view Testsets: {has_permission}")
if not has_permission:
error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
logger.error(error_msg)
return JSONResponse(
{"detail": error_msg},
status_code=403,

return [
TestSetOutputResponse(
_id=str(testset.id), # type: ignore
name=testset.name,
created_at=str(testset.created_at),
updated_at=str(testset.updated_at),
)
for testset in testsets
]

testsets = await db_manager.fetch_testsets_by_project_id(
project_id=str(app.project_id)
)
return [
TestSetOutputResponse(
_id=str(testset.id), # type: ignore
name=testset.name,
created_at=str(testset.created_at),
updated_at=str(testset.updated_at),
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")

raise HTTPException(
status_code=500,
detail=str(e),
)
for testset in testsets
]


@router.get("/{testset_id}/", operation_id="get_single_testset")
Expand Down
34 changes: 1 addition & 33 deletions agenta-backend/agenta_backend/services/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1892,38 +1892,6 @@ async def remove_testsets(testset_ids: List[str]):
await session.commit()


async def remove_app_testsets(app_id: str):
"""Returns a list of testsets owned by an app.
Args:
app_id (str): The name of the app
Returns:
int: The number of testsets deleted
"""

# Find testsets owned by the app
deleted_count: int = 0

async with engine.session() as session:
result = await session.execute(
select(TestSetDB).filter_by(app_id=uuid.UUID(app_id))
)
testsets = result.scalars().all()

if len(testsets) == 0:
logger.info(f"No testsets found for app {app_id}")
return 0

for testset in testsets:
await session.delete(testset)
deleted_count += 1
logger.info(f"{deleted_count} testset(s) deleted for app {app_id}")

await session.commit()
return deleted_count


async def remove_base_from_db(base_id: str):
"""
Remove a base from the database.
Expand Down Expand Up @@ -2086,7 +2054,7 @@ async def fetch_testset_by_id(testset_id: str) -> Optional[TestSetDB]:
return testset


async def create_testset(app: AppDB, project_id: str, testset_data: Dict[str, Any]):
async def create_testset(project_id: str, testset_data: Dict[str, Any]):
"""
Creates a testset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,29 @@

@pytest.mark.asyncio
async def test_create_testset():
async with engine.session() as session:
result = await session.execute(
select(AppDB).filter_by(app_name="app_variant_test")
)
app = result.scalars().first()

payload = {
"name": "create_testset_main",
"csvdata": [
{
"country": "Comoros",
"correct_answer": "The capital of Comoros is Moroni",
},
{
"country": "Kyrgyzstan",
"correct_answer": "The capital of Kyrgyzstan is Bishkek",
},
{
"country": "Azerbaijan",
"correct_answer": "The capital of Azerbaijan is Baku",
},
],
}
response = await test_client.post(
f"{BACKEND_API_HOST}/testsets/{str(app.id)}/", json=payload
)
assert response.status_code == 200
assert response.json()["name"] == payload["name"]
payload = {
"name": "create_testset_main",
"csvdata": [
{
"country": "Comoros",
"correct_answer": "The capital of Comoros is Moroni",
},
{
"country": "Kyrgyzstan",
"correct_answer": "The capital of Kyrgyzstan is Bishkek",
},
{
"country": "Azerbaijan",
"correct_answer": "The capital of Azerbaijan is Baku",
},
],
}
response = await test_client.post(
f"{BACKEND_API_HOST}/testsets",
json=payload,
)
assert response.status_code == 200
assert response.json()["name"] == payload["name"]


@pytest.mark.asyncio
Expand Down Expand Up @@ -101,18 +96,10 @@ async def test_update_testset():

@pytest.mark.asyncio
async def test_get_testsets():
async with engine.session() as session:
result = await session.execute(
select(AppDB).filter_by(app_name="app_variant_test")
)
app = result.scalars().first()
response = await test_client.get(f"{BACKEND_API_HOST}/testsets")

response = await test_client.get(
f"{BACKEND_API_HOST}/testsets/?app_id={str(app.id)}"
)

assert response.status_code == 200
assert len(response.json()) == 2
assert response.status_code == 200
assert len(response.json()) == 2


@pytest.mark.asyncio()
Expand Down
Loading

0 comments on commit 7fcc849

Please sign in to comment.