Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] global testsets #2266

Merged
merged 21 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions agenta-backend/agenta_backend/routers/evaluation_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
response_model=List[str],
)
async def fetch_evaluation_ids(
app_id: str,
resource_type: str,
request: Request,
resource_ids: List[str] = Query(None),
Expand All @@ -52,11 +51,10 @@ async def fetch_evaluation_ids(
List[str]: A list of evaluation ids.
"""
try:
app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
permission=Permission.VIEW_EVALUATION,
)
logger.debug(
Expand All @@ -70,7 +68,9 @@ async def fetch_evaluation_ids(
status_code=403,
)
evaluations = await db_manager.fetch_evaluations_by_resource(
resource_type, str(app.project_id), resource_ids
resource_type,
request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
resource_ids,
)
return list(map(lambda x: str(x.id), evaluations))
except Exception as exc:
Expand Down
96 changes: 56 additions & 40 deletions agenta-backend/agenta_backend/routers/testset_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ async def upload_file(
upload_type: str = Form(None),
file: UploadFile = File(...),
testset_name: Optional[str] = File(None),
app_id: str = Form(None),
):
"""
Uploads a CSV or JSON file and saves its data to MongoDB.
Expand All @@ -67,11 +66,10 @@ async def upload_file(
dict: The result of the upload process.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to upload Testset: {has_permission}")
Expand Down Expand Up @@ -114,7 +112,8 @@ async def upload_file(

try:
testset = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=document
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
testset_data=document,
)
return TestSetSimpleResponse(
id=str(testset.id),
Expand All @@ -132,7 +131,6 @@ async def import_testset(
request: Request,
endpoint: str = Form(None),
testset_name: str = Form(None),
app_id: str = Form(None),
):
"""
Import JSON testset data from an endpoint and save it to MongoDB.
Expand All @@ -145,11 +143,10 @@ async def import_testset(
dict: The result of the import process.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to import Testset: {has_permission}")
Expand Down Expand Up @@ -180,7 +177,8 @@ async def import_testset(
document["csvdata"].append(row)

testset = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=document
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
testset_data=document,
)
return TestSetSimpleResponse(
id=str(testset.id),
Expand All @@ -204,30 +202,30 @@ async def import_testset(


@router.post(
"/{app_id}/", response_model=TestSetSimpleResponse, operation_id="create_testset"
"/{app_id}",
response_model=TestSetSimpleResponse,
operation_id="deprecating_create_testset",
)
@router.post("/", response_model=TestSetSimpleResponse, operation_id="create_testset")
async def create_testset(
app_id: str,
csvdata: NewTestset,
request: Request,
):
"""
Create a testset with given name and app_name, save the testset to MongoDB.
Create a testset with given name, save the testset to MongoDB.
Args:
name (str): name of the test set.
app_name (str): name of the application.
testset (Dict[str, str]): test set data.
Returns:
str: The id of the test set created.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
permission=Permission.CREATE_TESTSET,
)
logger.debug(f"User has Permission to create Testset: {has_permission}")
Expand All @@ -245,7 +243,8 @@ async def create_testset(
"csvdata": csvdata.csvdata,
}
testset_instance = await db_manager.create_testset(
app=app, project_id=str(app.project_id), testset_data=testset_data
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
project_id=request.state.project_id,
testset_data=testset_data,
)
if testset_instance is not None:
return TestSetSimpleResponse(
Expand Down Expand Up @@ -315,7 +314,6 @@ async def update_testset(

@router.get("/", operation_id="get_testsets")
async def get_testsets(
app_id: str,
request: Request,
) -> List[TestSetOutputResponse]:
"""
Expand All @@ -328,34 +326,52 @@ async def get_testsets(
- `HTTPException` with status code 404 if no testsets are found.
"""

app = await db_manager.fetch_app_by_id(app_id=app_id)
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=str(app.project_id),
permission=Permission.VIEW_TESTSET,
try:
if isCloudEE():
has_permission = await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_TESTSET,
)

logger.debug(
"User has Permission to view Testsets: %s",
has_permission,
)

if not has_permission:
error_msg = (
"You do not have permission to perform this action. "
+ "Please contact your organization admin."
)
logger.error(error_msg)

return JSONResponse(
status_code=403,
content={"detail": error_msg},
)

testsets = await db_manager.fetch_testsets_by_project_id(
project_id=request.state.project_id,
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
)
logger.debug(f"User has Permission to view Testsets: {has_permission}")
if not has_permission:
error_msg = f"You do not have permission to perform this action. Please contact your organization admin."
logger.error(error_msg)
return JSONResponse(
{"detail": error_msg},
status_code=403,

return [
TestSetOutputResponse(
_id=str(testset.id), # type: ignore
name=testset.name,
created_at=str(testset.created_at),
updated_at=str(testset.updated_at),
)
for testset in testsets
]

testsets = await db_manager.fetch_testsets_by_project_id(
project_id=str(app.project_id)
)
return [
TestSetOutputResponse(
_id=str(testset.id), # type: ignore
name=testset.name,
created_at=str(testset.created_at),
updated_at=str(testset.updated_at),
except Exception as e:
logger.exception(f"An error occurred: {str(e)}")

raise HTTPException(
status_code=500,
detail=str(e),
)
for testset in testsets
]


@router.get("/{testset_id}/", operation_id="get_single_testset")
Expand Down
34 changes: 1 addition & 33 deletions agenta-backend/agenta_backend/services/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,38 +1843,6 @@ async def remove_testsets(testset_ids: List[str]):
await session.commit()


async def remove_app_testsets(app_id: str):
"""Returns a list of testsets owned by an app.

Args:
app_id (str): The name of the app

Returns:
int: The number of testsets deleted
"""

# Find testsets owned by the app
deleted_count: int = 0

async with engine.session() as session:
result = await session.execute(
select(TestSetDB).filter_by(app_id=uuid.UUID(app_id))
)
testsets = result.scalars().all()

if len(testsets) == 0:
logger.info(f"No testsets found for app {app_id}")
return 0

for testset in testsets:
await session.delete(testset)
deleted_count += 1
logger.info(f"{deleted_count} testset(s) deleted for app {app_id}")

await session.commit()
return deleted_count


async def remove_base_from_db(base_id: str):
"""
Remove a base from the database.
Expand Down Expand Up @@ -2037,7 +2005,7 @@ async def fetch_testset_by_id(testset_id: str) -> Optional[TestSetDB]:
return testset


async def create_testset(app: AppDB, project_id: str, testset_data: Dict[str, Any]):
async def create_testset(project_id: str, testset_data: Dict[str, Any]):
"""
Creates a testset.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,29 @@

@pytest.mark.asyncio
async def test_create_testset():
async with engine.session() as session:
result = await session.execute(
select(AppDB).filter_by(app_name="app_variant_test")
)
app = result.scalars().first()

payload = {
"name": "create_testset_main",
"csvdata": [
{
"country": "Comoros",
"correct_answer": "The capital of Comoros is Moroni",
},
{
"country": "Kyrgyzstan",
"correct_answer": "The capital of Kyrgyzstan is Bishkek",
},
{
"country": "Azerbaijan",
"correct_answer": "The capital of Azerbaijan is Baku",
},
],
}
response = await test_client.post(
f"{BACKEND_API_HOST}/testsets/{str(app.id)}/", json=payload
)
assert response.status_code == 200
assert response.json()["name"] == payload["name"]
payload = {
"name": "create_testset_main",
"csvdata": [
{
"country": "Comoros",
"correct_answer": "The capital of Comoros is Moroni",
},
{
"country": "Kyrgyzstan",
"correct_answer": "The capital of Kyrgyzstan is Bishkek",
},
{
"country": "Azerbaijan",
"correct_answer": "The capital of Azerbaijan is Baku",
},
],
}
response = await test_client.post(
f"{BACKEND_API_HOST}/testsets",
json=payload,
)
assert response.status_code == 200
assert response.json()["name"] == payload["name"]


@pytest.mark.asyncio
Expand Down Expand Up @@ -101,18 +96,10 @@ async def test_update_testset():

@pytest.mark.asyncio
async def test_get_testsets():
async with engine.session() as session:
result = await session.execute(
select(AppDB).filter_by(app_name="app_variant_test")
)
app = result.scalars().first()
response = await test_client.get(f"{BACKEND_API_HOST}/testsets")

response = await test_client.get(
f"{BACKEND_API_HOST}/testsets/?app_id={str(app.id)}"
)

assert response.status_code == 200
assert len(response.json()) == 2
assert response.status_code == 200
assert len(response.json()) == 2


@pytest.mark.asyncio()
Expand Down
Loading
Loading