From 8d443ada5be4b56a5944fd2f59e727ff03ecbf06 Mon Sep 17 00:00:00 2001 From: hagen-danswer Date: Sun, 1 Sep 2024 15:21:00 -0700 Subject: [PATCH] Integration tests (#2256) * initial commit * almost done * finished 3 tests * minor refactor * built out initial permisison tests * reworked test_deletion * removed logging * all original tests have been converted * renamed user_groups to user_group * mypy * added test for doc set permissions * unified naming for manager methods * Refactored models and added new deletion test * minor additions * better logging+fixed input variables * commented out failed tests * Added readme * readme update * Added auth to IT set auth_type to basic and require_email_verification to false * Update run-it.yml * used verify and added to readme * added api key manager --- .github/workflows/run-it.yml | 2 + backend/danswer/auth/users.py | 17 - .../danswer/db/connector_credential_pair.py | 12 +- backend/danswer/server/documents/cc_pair.py | 19 +- backend/danswer/server/documents/connector.py | 53 +-- .../danswer/server/documents/credential.py | 23 +- backend/danswer/server/documents/models.py | 14 +- .../server/features/document_set/api.py | 25 +- backend/danswer/server/manage/users.py | 4 +- backend/ee/danswer/db/user_group.py | 45 ++ backend/ee/danswer/server/user_group/api.py | 4 + backend/tests/integration/README.md | 70 +++ .../integration/common_utils/connectors.py | 114 ----- .../integration/common_utils/constants.py | 4 + .../integration/common_utils/document_sets.py | 30 -- backend/tests/integration/common_utils/llm.py | 106 +++-- .../common_utils/managers/api_key.py | 92 ++++ .../common_utils/managers/cc_pair.py | 202 +++++++++ .../common_utils/managers/connector.py | 124 ++++++ .../common_utils/managers/credential.py | 129 ++++++ .../common_utils/managers/document.py | 153 +++++++ .../common_utils/managers/document_set.py | 171 ++++++++ .../common_utils/managers/persona.py | 206 +++++++++ .../integration/common_utils/managers/user.py | 122 ++++++ .../common_utils/managers/user_group.py | 148 +++++++ .../tests/integration/common_utils/reset.py | 3 - .../common_utils/seed_documents.py | 72 --- .../integration/common_utils/test_models.py | 120 +++++ .../integration/common_utils/user_groups.py | 24 - backend/tests/integration/conftest.py | 20 + .../tests/connector/test_deletion.py | 413 +++++++++++------- .../tests/dev_apis/test_simple_chat_api.py | 41 +- .../tests/document_set/test_syncing.py | 110 +++-- .../permissions/test_cc_pair_permissions.py | 179 ++++++++ .../permissions/test_connector_permissions.py | 136 ++++++ .../test_credential_permissions.py | 108 +++++ .../permissions/test_doc_set_permissions.py | 186 ++++++++ .../permissions/test_user_role_permissions.py | 93 ++++ .../permissions/test_whole_curator_flow.py | 86 ++++ web/src/lib/credential.ts | 4 +- 40 files changed, 2881 insertions(+), 603 deletions(-) create mode 100644 backend/tests/integration/README.md delete mode 100644 backend/tests/integration/common_utils/connectors.py delete mode 100644 backend/tests/integration/common_utils/document_sets.py create mode 100644 backend/tests/integration/common_utils/managers/api_key.py create mode 100644 backend/tests/integration/common_utils/managers/cc_pair.py create mode 100644 backend/tests/integration/common_utils/managers/connector.py create mode 100644 backend/tests/integration/common_utils/managers/credential.py create mode 100644 backend/tests/integration/common_utils/managers/document.py create mode 100644 backend/tests/integration/common_utils/managers/document_set.py create mode 100644 backend/tests/integration/common_utils/managers/persona.py create mode 100644 backend/tests/integration/common_utils/managers/user.py create mode 100644 backend/tests/integration/common_utils/managers/user_group.py delete mode 100644 backend/tests/integration/common_utils/seed_documents.py create mode 100644 backend/tests/integration/common_utils/test_models.py delete mode 100644 backend/tests/integration/common_utils/user_groups.py create mode 100644 backend/tests/integration/tests/permissions/test_cc_pair_permissions.py create mode 100644 backend/tests/integration/tests/permissions/test_connector_permissions.py create mode 100644 backend/tests/integration/tests/permissions/test_credential_permissions.py create mode 100644 backend/tests/integration/tests/permissions/test_doc_set_permissions.py create mode 100644 backend/tests/integration/tests/permissions/test_user_role_permissions.py create mode 100644 backend/tests/integration/tests/permissions/test_whole_curator_flow.py diff --git a/.github/workflows/run-it.yml b/.github/workflows/run-it.yml index 7c0c1814c3b..45d57493b9a 100644 --- a/.github/workflows/run-it.yml +++ b/.github/workflows/run-it.yml @@ -92,6 +92,8 @@ jobs: run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ + AUTH_TYPE=basic \ + REQUIRE_EMAIL_VERIFICATION=false \ IMAGE_TAG=it \ docker compose -f docker-compose.dev.yml -p danswer-stack up -d --build id: start_docker diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index dff6a60363c..453e34cbe3d 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -67,23 +67,6 @@ logger = setup_logger() -def validate_curator_request(groups: list | None, is_public: bool) -> None: - if is_public: - detail = "Curators cannot create public objects" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - if not groups: - detail = "Curators must specify 1+ groups" - logger.error(detail) - raise HTTPException( - status_code=401, - detail=detail, - ) - - def is_user_admin(user: User | None) -> bool: if AUTH_TYPE == AuthType.DISABLED: return True diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 2a9d4cf76e5..f35aed9186c 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -334,9 +334,13 @@ def add_credential_to_connector( raise HTTPException(status_code=404, detail="Connector does not exist") if credential is None: + error_msg = ( + f"Credential {credential_id} does not exist or does not belong to user" + ) + logger.error(error_msg) raise HTTPException( status_code=401, - detail="Credential does not exist or does not belong to user", + detail=error_msg, ) existing_association = ( @@ -350,7 +354,7 @@ def add_credential_to_connector( if existing_association is not None: return StatusResponse( success=False, - message=f"Connector already has Credential {credential_id}", + message=f"Connector {connector_id} already has Credential {credential_id}", data=connector_id, ) @@ -374,8 +378,8 @@ def add_credential_to_connector( db_session.commit() return StatusResponse( - success=False, - message=f"Connector already has Credential {credential_id}", + success=True, + message=f"Creating new association between Connector {connector_id} and Credential {credential_id}", data=association.id, ) diff --git a/backend/danswer/server/documents/cc_pair.py b/backend/danswer/server/documents/cc_pair.py index 37c7cbfd0b1..99e7bae61f9 100644 --- a/backend/danswer/server/documents/cc_pair.py +++ b/backend/danswer/server/documents/cc_pair.py @@ -1,7 +1,6 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException -from pydantic import BaseModel from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -21,12 +20,13 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import get_index_attempts_for_connector from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CCPairFullInfo +from danswer.server.documents.models import CCStatusUpdateRequest from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorCredentialPairMetadata from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -84,10 +84,6 @@ def get_cc_pair_full_info( ) -class CCStatusUpdateRequest(BaseModel): - status: ConnectorCredentialPairStatus - - @router.put("/admin/cc-pair/{cc_pair_id}/status") def update_cc_pair_status( cc_pair_id: int, @@ -157,11 +153,12 @@ def associate_credential_to_connector( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: - if user and user.role != UserRole.ADMIN and metadata.is_public: - raise HTTPException( - status_code=400, - detail="Public connections cannot be created by non-admin users", - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=metadata.groups, + object_is_public=metadata.is_public, + ) try: response = add_credential_to_connector( diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index a742e246491..52bec7d7fb7 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -75,7 +75,6 @@ from danswer.file_store.file_store import get_default_file_store from danswer.server.documents.models import AuthStatus from danswer.server.documents.models import AuthUrl -from danswer.server.documents.models import ConnectorBase from danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.server.documents.models import ConnectorIndexingStatus from danswer.server.documents.models import ConnectorSnapshot @@ -93,6 +92,7 @@ from danswer.server.documents.models import RunConnectorRequest from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -514,35 +514,6 @@ def _validate_connector_allowed(source: DocumentSource) -> None: ) -def _check_connector_permissions( - connector_data: ConnectorUpdateRequest, user: User | None -) -> ConnectorBase: - """ - This is not a proper permission check, but this should prevent curators creating bad situations - until a long-term solution is implemented (Replacing CC pairs/Connectors with Connections) - """ - if user and user.role != UserRole.ADMIN: - if connector_data.is_public: - raise HTTPException( - status_code=400, - detail="Public connectors can only be created by admins", - ) - if not connector_data.groups: - raise HTTPException( - status_code=400, - detail="Connectors created by curators must have groups", - ) - return ConnectorBase( - name=connector_data.name, - source=connector_data.source, - input_type=connector_data.input_type, - connector_specific_config=connector_data.connector_specific_config, - refresh_freq=connector_data.refresh_freq, - prune_freq=connector_data.prune_freq, - indexing_start=connector_data.indexing_start, - ) - - @router.post("/admin/connector") def create_connector_from_model( connector_data: ConnectorUpdateRequest, @@ -551,13 +522,19 @@ def create_connector_from_model( ) -> ObjectCreationIdResponse: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) - + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + connector_base = connector_data.to_connector_base() return create_connector( db_session=db_session, connector_data=connector_base, ) except ValueError as e: + logger.error(f"Error creating connector: {e}") raise HTTPException(status_code=400, detail=str(e)) @@ -608,12 +585,18 @@ def create_connector_with_mock_credential( def update_connector_from_model( connector_id: int, connector_data: ConnectorUpdateRequest, - user: User = Depends(current_admin_user), + user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorSnapshot | StatusResponse[int]: try: _validate_connector_allowed(connector_data.source) - connector_base = _check_connector_permissions(connector_data, user) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=connector_data.groups, + object_is_public=connector_data.is_public, + ) + connector_base = connector_data.to_connector_base() except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -643,7 +626,7 @@ def update_connector_from_model( @router.delete("/admin/connector/{connector_id}", response_model=StatusResponse[int]) def delete_connector_by_id( connector_id: int, - _: User = Depends(current_admin_user), + _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: try: diff --git a/backend/danswer/server/documents/credential.py b/backend/danswer/server/documents/credential.py index ba30b65f2f9..3d965481bf5 100644 --- a/backend/danswer/server/documents/credential.py +++ b/backend/danswer/server/documents/credential.py @@ -7,7 +7,6 @@ from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.credentials import alter_credential from danswer.db.credentials import create_credential from danswer.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE @@ -20,7 +19,6 @@ from danswer.db.engine import get_session from danswer.db.models import DocumentSource from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.documents.models import CredentialBase from danswer.server.documents.models import CredentialDataUpdateRequest from danswer.server.documents.models import CredentialSnapshot @@ -28,6 +26,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger +from ee.danswer.db.user_group import validate_user_creation_permissions logger = setup_logger() @@ -80,7 +79,7 @@ def get_cc_source_full_info( ] -@router.get("/credentials/{id}") +@router.get("/credential/{id}") def list_credentials_by_id( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), @@ -105,7 +104,7 @@ def delete_credential_by_id_admin( ) -@router.put("/admin/credentials/swap") +@router.put("/admin/credential/swap") def swap_credentials_for_connector( credential_swap_req: CredentialSwapRequest, user: User | None = Depends(current_user), @@ -131,14 +130,12 @@ def create_credential_from_model( user: User | None = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: - if ( - user - and user.role != UserRole.ADMIN - and not _ignore_credential_permissions(credential_info.source) - ): - validate_curator_request( - groups=credential_info.groups, - is_public=credential_info.curator_public, + if not _ignore_credential_permissions(credential_info.source): + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=credential_info.groups, + object_is_public=credential_info.curator_public, ) credential = create_credential(credential_info, user, db_session) @@ -179,7 +176,7 @@ def get_credential_by_id( return CredentialSnapshot.from_credential_db_model(credential) -@router.put("/admin/credentials/{credential_id}") +@router.put("/admin/credential/{credential_id}") def update_credential_data( credential_id: int, credential_update: CredentialDataUpdateRequest, diff --git a/backend/danswer/server/documents/models.py b/backend/danswer/server/documents/models.py index ba011afc196..3805ccc4157 100644 --- a/backend/danswer/server/documents/models.py +++ b/backend/danswer/server/documents/models.py @@ -48,9 +48,12 @@ class ConnectorBase(BaseModel): class ConnectorUpdateRequest(ConnectorBase): - is_public: bool | None = None + is_public: bool = True groups: list[int] = Field(default_factory=list) + def to_connector_base(self) -> ConnectorBase: + return ConnectorBase(**self.model_dump(exclude={"is_public", "groups"})) + class ConnectorSnapshot(ConnectorBase): id: int @@ -103,11 +106,6 @@ class CredentialSnapshot(CredentialBase): user_id: UUID | None time_created: datetime time_updated: datetime - name: str | None - source: DocumentSource - credential_json: dict[str, Any] - admin_public: bool - curator_public: bool @classmethod def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot": @@ -261,6 +259,10 @@ class ConnectorCredentialPairMetadata(BaseModel): groups: list[int] = Field(default_factory=list) +class CCStatusUpdateRequest(BaseModel): + status: ConnectorCredentialPairStatus + + class ConnectorCredentialPairDescriptor(BaseModel): id: int name: str | None = None diff --git a/backend/danswer/server/features/document_set/api.py b/backend/danswer/server/features/document_set/api.py index d1eff082891..c9cea2cf2a2 100644 --- a/backend/danswer/server/features/document_set/api.py +++ b/backend/danswer/server/features/document_set/api.py @@ -6,7 +6,6 @@ from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user -from danswer.auth.users import validate_curator_request from danswer.db.document_set import check_document_sets_are_public from danswer.db.document_set import fetch_all_document_sets_for_user from danswer.db.document_set import insert_document_set @@ -14,12 +13,12 @@ from danswer.db.document_set import update_document_set from danswer.db.engine import get_session from danswer.db.models import User -from danswer.db.models import UserRole from danswer.server.features.document_set.models import CheckDocSetPublicRequest from danswer.server.features.document_set.models import CheckDocSetPublicResponse from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.document_set.models import DocumentSetCreationRequest from danswer.server.features.document_set.models import DocumentSetUpdateRequest +from ee.danswer.db.user_group import validate_user_creation_permissions router = APIRouter(prefix="/manage") @@ -31,11 +30,12 @@ def create_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> int: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - groups=document_set_creation_request.groups, - is_public=document_set_creation_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_creation_request.groups, + object_is_public=document_set_creation_request.is_public, + ) try: document_set_db_model, _ = insert_document_set( document_set_creation_request=document_set_creation_request, @@ -53,11 +53,12 @@ def patch_document_set( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: - if user and user.role != UserRole.ADMIN: - validate_curator_request( - groups=document_set_update_request.groups, - is_public=document_set_update_request.is_public, - ) + validate_user_creation_permissions( + db_session=db_session, + user=user, + target_group_ids=document_set_update_request.groups, + object_is_public=document_set_update_request.is_public, + ) try: update_document_set( document_set_update_request=document_set_update_request, diff --git a/backend/danswer/server/manage/users.py b/backend/danswer/server/manage/users.py index d2fd981b5b5..620ddd3b4b2 100644 --- a/backend/danswer/server/manage/users.py +++ b/backend/danswer/server/manage/users.py @@ -69,7 +69,7 @@ def set_user_role( if user_role_update_request.new_role == UserRole.CURATOR: raise HTTPException( - status_code=400, + status_code=402, detail="Curator role must be set via the User Group Menu", ) @@ -78,7 +78,7 @@ def set_user_role( if current_user.id == user_to_update.id: raise HTTPException( - status_code=400, + status_code=402, detail="An admin cannot demote themselves from admin role!", ) diff --git a/backend/ee/danswer/db/user_group.py b/backend/ee/danswer/db/user_group.py index 9d172c5d716..998587fa2b3 100644 --- a/backend/ee/danswer/db/user_group.py +++ b/backend/ee/danswer/db/user_group.py @@ -2,6 +2,7 @@ from operator import and_ from uuid import UUID +from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import select @@ -30,6 +31,50 @@ logger = setup_logger() +def validate_user_creation_permissions( + db_session: Session, + user: User | None, + target_group_ids: list[int] | None, + object_is_public: bool | None, +) -> None: + """ + All admin actions are allowed. + Prevents non-admins from creating/editing: + - public objects + - objects with no groups + - objects that belong to a group they don't curate + """ + if not user or user.role == UserRole.ADMIN: + return + + if object_is_public: + detail = "User does not have permission to create public credentials" + logger.error(detail) + raise HTTPException( + status_code=402, + detail=detail, + ) + if not target_group_ids: + detail = "Curators must specify 1+ groups" + logger.error(detail) + raise HTTPException( + status_code=402, + detail=detail, + ) + user_curated_groups = fetch_user_groups_for_user( + db_session=db_session, user_id=user.id, only_curator_groups=True + ) + user_curated_group_ids = set([group.id for group in user_curated_groups]) + target_group_ids_set = set(target_group_ids) + if not target_group_ids_set.issubset(user_curated_group_ids): + detail = "Curators cannot control groups they don't curate" + logger.error(detail) + raise HTTPException( + status_code=402, + detail=detail, + ) + + def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) return db_session.scalar(stmt) diff --git a/backend/ee/danswer/server/user_group/api.py b/backend/ee/danswer/server/user_group/api.py index e18487d5491..b33daddea64 100644 --- a/backend/ee/danswer/server/user_group/api.py +++ b/backend/ee/danswer/server/user_group/api.py @@ -9,6 +9,7 @@ from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.models import UserRole +from danswer.utils.logger import setup_logger from ee.danswer.db.user_group import fetch_user_groups from ee.danswer.db.user_group import fetch_user_groups_for_user from ee.danswer.db.user_group import insert_user_group @@ -20,6 +21,8 @@ from ee.danswer.server.user_group.models import UserGroupCreate from ee.danswer.server.user_group.models import UserGroupUpdate +logger = setup_logger() + router = APIRouter(prefix="/manage") @@ -90,6 +93,7 @@ def set_user_curator( set_curator_request=set_curator_request, ) except ValueError as e: + logger.error(f"Error setting user curator: {e}") raise HTTPException(status_code=404, detail=str(e)) diff --git a/backend/tests/integration/README.md b/backend/tests/integration/README.md new file mode 100644 index 00000000000..bc5e388082f --- /dev/null +++ b/backend/tests/integration/README.md @@ -0,0 +1,70 @@ +# Integration Tests + +## General Testing Overview +The integration tests are designed with a "manager" class and a "test" class for each type of object being manipulated (e.g., user, persona, credential): +- **Manager Class**: Contains methods for each type of API call. Responsible for creating, deleting, and verifying the existence of an entity. +- **Test Class**: Stores data for each entity being tested. This is our "expected state" of the object. + +The idea is that each test can use the manager class to create (.create()) a "test_" object. It can then perform an operation on the object (e.g., send a request to the API) and then check if the "test_" object is in the expected state by using the manager class (.verify()) function. + +## Instructions for Running Integration Tests Locally +1. Launch danswer (using Docker or running with a debugger), ensuring the API server is running on port 8080. + a. If you'd like to set environment variables, you can do so by creating a `.env` file in the danswer/backend/tests/integration/ directory. +2. Navigate to `danswer/backend`. +3. Run the following command in the terminal: + ```sh + pytest -s tests/integration/tests/ + ``` + or to run all tests in a file: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py + ``` + or to run a single test: + ```sh + pytest -s tests/integration/tests/path_to/test_file.py::test_function_name + ``` + +## Guidelines for Writing Integration Tests +- As authentication is currently required for all tests, each test should start by creating a user. +- Each test should ideally focus on a single API flow. +- The test writer should try to consider failure cases and edge cases for the flow and write the tests to check for these cases. +- Every step of the test should be commented describing what is being done and what the expected behavior is. +- A summary of the test should be given at the top of the test function as well! +- When writing new tests, manager classes, manager functions, and test classes, try to copy the style of the other ones that have already been written. +- Be careful for scope creep! + - No need to overcomplicate every test by verifying after every single API call so long as the case you would be verifying is covered elsewhere (ideally in a test focused on covering that case). + - An example of this is: Creating an admin user is done at the beginning of nearly every test, but we only need to verify that the user is actually an admin in the test focused on checking admin permissions. For every other test, we can just create the admin user and assume that the permissions are working as expected. + +## Current Testing Limitations +### Test coverage +- All tests are probably not as high coverage as they could be. +- The "connector" tests in particular are super bare bones because we will be reworking connector/cc_pair sometime soon. +- Global Curator role is not thoroughly tested. +- No auth is not tested at all. +### Failure checking +- While we test expected auth failures, we only check that it failed at all. +- We dont check that the return codes are what we expect. +- This means that a test could be failing for a different reason than expected. +- We should ensure that the proper codes are being returned for each failure case. +- We should also query the db after each failure to ensure that the db is in the expected state. +### Scope/focus +- The tests may be scoped sub-optimally. +- The scoping of each test may be overlapping. + +## Current Testing Coverage +The current testing coverage should be checked by reading the comments at the top of each test file. + + +## TODO: Testing Coverage +- Persona permissions testing +- Read only (and/or basic) user permissions + - Ensuring proper permission enforcement using the chat/doc_search endpoints +- No auth + +## Ideas for integration testing design +### Combine the "test" and "manager" classes +This could make test writing a bit cleaner by preventing test writers from having to pass around objects into functions that the objects have a 1:1 relationship with. + +### Rework VespaClient +Right now, its used a fixture and has to be passed around between manager classes. +Could just be built where its used diff --git a/backend/tests/integration/common_utils/connectors.py b/backend/tests/integration/common_utils/connectors.py deleted file mode 100644 index e7734cec3c8..00000000000 --- a/backend/tests/integration/common_utils/connectors.py +++ /dev/null @@ -1,114 +0,0 @@ -import uuid -from typing import cast - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from danswer.db.enums import ConnectorCredentialPairStatus -from tests.integration.common_utils.constants import API_SERVER_URL - - -class ConnectorCreationDetails(BaseModel): - connector_id: int - credential_id: int - cc_pair_id: int - - -class ConnectorClient: - @staticmethod - def create_connector( - name_prefix: str = "test_connector", credential_id: int | None = None - ) -> ConnectorCreationDetails: - unique_id = uuid.uuid4() - - connector_name = f"{name_prefix}_{unique_id}" - connector_data = { - "name": connector_name, - "source": DocumentSource.NOT_APPLICABLE, - "input_type": "load_state", - "connector_specific_config": {}, - "refresh_freq": 60, - "disabled": True, - } - response = requests.post( - f"{API_SERVER_URL}/manage/admin/connector", - json=connector_data, - ) - response.raise_for_status() - connector_id = response.json()["id"] - - # associate the credential with the connector - if not credential_id: - print("ID not specified, creating new credential") - # Create a new credential - credential_data = { - "credential_json": {}, - "admin_public": True, - "source": DocumentSource.NOT_APPLICABLE, - } - response = requests.post( - f"{API_SERVER_URL}/manage/credential", - json=credential_data, - ) - response.raise_for_status() - credential_id = cast(int, response.json()["id"]) - - cc_pair_metadata = {"name": f"test_cc_pair_{unique_id}", "is_public": True} - response = requests.put( - f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", - json=cc_pair_metadata, - ) - response.raise_for_status() - - # fetch the conenector credential pair id using the indexing status API - response = requests.get( - f"{API_SERVER_URL}/manage/admin/connector/indexing-status" - ) - response.raise_for_status() - indexing_statuses = response.json() - - cc_pair_id = None - for status in indexing_statuses: - if ( - status["connector"]["id"] == connector_id - and status["credential"]["id"] == credential_id - ): - cc_pair_id = status["cc_pair_id"] - break - - if cc_pair_id is None: - raise ValueError("Could not find the connector credential pair id") - - print( - f"Created connector with connector_id: {connector_id}, credential_id: {credential_id}, cc_pair_id: {cc_pair_id}" - ) - return ConnectorCreationDetails( - connector_id=int(connector_id), - credential_id=int(credential_id), - cc_pair_id=int(cc_pair_id), - ) - - @staticmethod - def update_connector_status( - cc_pair_id: int, status: ConnectorCredentialPairStatus - ) -> None: - response = requests.put( - f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/status", - json={"status": status}, - ) - response.raise_for_status() - - @staticmethod - def delete_connector(connector_id: int, credential_id: int) -> None: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/deletion-attempt", - json={"connector_id": connector_id, "credential_id": credential_id}, - ) - response.raise_for_status() - - @staticmethod - def get_connectors() -> list[dict]: - response = requests.get(f"{API_SERVER_URL}/manage/connector") - response.raise_for_status() - return response.json() diff --git a/backend/tests/integration/common_utils/constants.py b/backend/tests/integration/common_utils/constants.py index efc98dde7de..7d729191cf6 100644 --- a/backend/tests/integration/common_utils/constants.py +++ b/backend/tests/integration/common_utils/constants.py @@ -5,3 +5,7 @@ API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080" API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}" MAX_DELAY = 30 + +GENERAL_HEADERS = {"Content-Type": "application/json"} + +NUM_DOCS = 5 diff --git a/backend/tests/integration/common_utils/document_sets.py b/backend/tests/integration/common_utils/document_sets.py deleted file mode 100644 index dc898611108..00000000000 --- a/backend/tests/integration/common_utils/document_sets.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import cast - -import requests - -from danswer.server.features.document_set.models import DocumentSet -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.constants import API_SERVER_URL - - -class DocumentSetClient: - @staticmethod - def create_document_set( - doc_set_creation_request: DocumentSetCreationRequest, - ) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/document-set", - json=doc_set_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()) - - @staticmethod - def fetch_document_sets() -> list[DocumentSet]: - response = requests.get(f"{API_SERVER_URL}/manage/document-set") - response.raise_for_status() - - document_sets = [ - DocumentSet.parse_obj(doc_set_data) for doc_set_data in response.json() - ] - return document_sets diff --git a/backend/tests/integration/common_utils/llm.py b/backend/tests/integration/common_utils/llm.py index ba8b89d6b4d..f74b40073c9 100644 --- a/backend/tests/integration/common_utils/llm.py +++ b/backend/tests/integration/common_utils/llm.py @@ -1,62 +1,88 @@ import os -from typing import cast +from uuid import uuid4 import requests -from pydantic import BaseModel -from pydantic import PrivateAttr from danswer.server.manage.llm.models import LLMProviderUpsertRequest from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestLLMProvider +from tests.integration.common_utils.test_models import TestUser -class LLMProvider(BaseModel): - provider: str - api_key: str - default_model_name: str - api_base: str | None = None - api_version: str | None = None - is_default: bool = True +class LLMProviderManager: + @staticmethod + def create( + name: str | None = None, + provider: str | None = None, + api_key: str | None = None, + default_model_name: str | None = None, + api_base: str | None = None, + api_version: str | None = None, + groups: list[int] | None = None, + is_public: bool | None = None, + user_performing_action: TestUser | None = None, + ) -> TestLLMProvider: + print("Seeding LLM Providers...") - # only populated after creation - _provider_id: int | None = PrivateAttr() - - def create(self) -> int: llm_provider = LLMProviderUpsertRequest( - name=self.provider, - provider=self.provider, - default_model_name=self.default_model_name, - api_key=self.api_key, - api_base=self.api_base, - api_version=self.api_version, + name=name or f"test-provider-{uuid4()}", + provider=provider or "openai", + default_model_name=default_model_name or "gpt-4o-mini", + api_key=api_key or os.environ["OPENAI_API_KEY"], + api_base=api_base, + api_version=api_version, custom_config=None, - fast_default_model_name=None, - is_public=True, - groups=[], + fast_default_model_name=default_model_name or "gpt-4o-mini", + is_public=is_public or True, + groups=groups or [], display_model_names=None, model_names=None, ) - response = requests.put( + llm_response = requests.put( f"{API_SERVER_URL}/admin/llm/provider", - json=llm_provider.dict(), + json=llm_provider.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + llm_response.raise_for_status() + response_data = llm_response.json() + result_llm = TestLLMProvider( + id=response_data["id"], + name=response_data["name"], + provider=response_data["provider"], + api_key=response_data["api_key"], + default_model_name=response_data["default_model_name"], + is_public=response_data["is_public"], + groups=response_data["groups"], + api_base=response_data["api_base"], + api_version=response_data["api_version"], ) - response.raise_for_status() - self._provider_id = cast(int, response.json()["id"]) - return self._provider_id + set_default_response = requests.post( + f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + set_default_response.raise_for_status() - def delete(self) -> None: + return result_llm + + @staticmethod + def delete( + llm_provider: TestLLMProvider, + user_performing_action: TestUser | None = None, + ) -> bool: + if not llm_provider.id: + raise ValueError("LLM Provider ID is required to delete a provider") response = requests.delete( - f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}" + f"{API_SERVER_URL}/admin/llm/provider/{llm_provider.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, ) response.raise_for_status() - - -def seed_default_openai_provider() -> LLMProvider: - llm = LLMProvider( - provider="openai", - default_model_name="gpt-4o-mini", - api_key=os.environ["OPENAI_API_KEY"], - ) - llm.create() - return llm + return True diff --git a/backend/tests/integration/common_utils/managers/api_key.py b/backend/tests/integration/common_utils/managers/api_key.py new file mode 100644 index 00000000000..b6d2c29b732 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/api_key.py @@ -0,0 +1,92 @@ +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from ee.danswer.server.api_key.models import APIKeyArgs +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser + + +class APIKeyManager: + @staticmethod + def create( + name: str | None = None, + api_key_role: UserRole = UserRole.ADMIN, + user_performing_action: TestUser | None = None, + ) -> TestAPIKey: + name = f"{name}-api-key" if name else f"test-api-key-{uuid4()}" + api_key_request = APIKeyArgs( + name=name, + role=api_key_role, + ) + api_key_response = requests.post( + f"{API_SERVER_URL}/admin/api-key", + json=api_key_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + api_key = api_key_response.json() + result_api_key = TestAPIKey( + api_key_id=api_key["api_key_id"], + api_key_display=api_key["api_key_display"], + api_key=api_key["api_key"], + api_key_name=name, + api_key_role=api_key_role, + user_id=api_key["user_id"], + headers=GENERAL_HEADERS, + ) + result_api_key.headers["Authorization"] = f"Bearer {result_api_key.api_key}" + return result_api_key + + @staticmethod + def delete( + api_key: TestAPIKey, + user_performing_action: TestUser | None = None, + ) -> None: + api_key_response = requests.delete( + f"{API_SERVER_URL}/admin/api-key/{api_key.api_key_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestAPIKey]: + api_key_response = requests.get( + f"{API_SERVER_URL}/admin/api-key", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + api_key_response.raise_for_status() + return [TestAPIKey(**api_key) for api_key in api_key_response.json()] + + @staticmethod + def verify( + api_key: TestAPIKey, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + retrieved_keys = APIKeyManager.get_all( + user_performing_action=user_performing_action + ) + for key in retrieved_keys: + if key.api_key_id == api_key.api_key_id: + if verify_deleted: + raise ValueError("API Key found when it should have been deleted") + if ( + key.api_key_name == api_key.api_key_name + and key.api_key_role == api_key.api_key_role + ): + return + + if not verify_deleted: + raise Exception("API Key not found") diff --git a/backend/tests/integration/common_utils/managers/cc_pair.py b/backend/tests/integration/common_utils/managers/cc_pair.py new file mode 100644 index 00000000000..6498252bbe8 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/cc_pair.py @@ -0,0 +1,202 @@ +import time +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.db.enums import ConnectorCredentialPairStatus +from danswer.server.documents.models import ConnectorCredentialPairIdentifier +from danswer.server.documents.models import ConnectorIndexingStatus +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser + + +def _cc_pair_creator( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, +) -> TestCCPair: + name = f"{name}-cc-pair" if name else f"test-cc-pair-{uuid4()}" + + request = { + "name": name, + "is_public": is_public, + "groups": groups or [], + } + + response = requests.put( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}/credential/{credential_id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return TestCCPair( + id=response.json()["data"], + name=name, + connector_id=connector_id, + credential_id=credential_id, + is_public=is_public, + groups=groups or [], + ) + + +class CCPairManager: + @staticmethod + def create_from_scratch( + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + credential_json: dict[str, Any] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + connector = ConnectorManager.create( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + credential = CredentialManager.create( + credential_json=credential_json, + name=name, + source=source, + curator_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + return _cc_pair_creator( + connector_id=connector.id, + credential_id=credential.id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def create( + connector_id: int, + credential_id: int, + name: str | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCCPair: + return _cc_pair_creator( + connector_id=connector_id, + credential_id=credential_id, + name=name, + is_public=is_public, + groups=groups, + user_performing_action=user_performing_action, + ) + + @staticmethod + def pause_cc_pair( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + result = requests.put( + url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/status", + json={"status": "PAUSED"}, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def delete( + cc_pair: TestCCPair, + user_performing_action: TestUser | None = None, + ) -> None: + cc_pair_identifier = ConnectorCredentialPairIdentifier( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + ) + result = requests.post( + url=f"{API_SERVER_URL}/manage/admin/deletion-attempt", + json=cc_pair_identifier.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + result.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[ConnectorIndexingStatus]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/connector/indexing-status", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ConnectorIndexingStatus(**cc_pair) for cc_pair in response.json()] + + @staticmethod + def verify( + cc_pair: TestCCPair, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_cc_pairs = CCPairManager.get_all(user_performing_action) + for retrieved_cc_pair in all_cc_pairs: + if retrieved_cc_pair.cc_pair_id == cc_pair.id: + if verify_deleted: + # We assume that this check will be performed after the deletion is + # already waited for + raise ValueError( + f"CC pair {cc_pair.id} found but should be deleted" + ) + if ( + retrieved_cc_pair.name == cc_pair.name + and retrieved_cc_pair.connector.id == cc_pair.connector_id + and retrieved_cc_pair.credential.id == cc_pair.credential_id + and retrieved_cc_pair.public_doc == cc_pair.is_public + and set(retrieved_cc_pair.groups) == set(cc_pair.groups) + ): + return + + if not verify_deleted: + raise ValueError(f"CC pair {cc_pair.id} not found") + + @staticmethod + def wait_for_deletion_completion( + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + cc_pairs = CCPairManager.get_all(user_performing_action) + if all( + cc_pair.cc_pair_status != ConnectorCredentialPairStatus.DELETING + for cc_pair in cc_pairs + ): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"CC pairs deletion was not completed within the {MAX_DELAY} seconds" + ) + else: + print("Some CC pairs are still being deleted, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/managers/connector.py b/backend/tests/integration/common_utils/managers/connector.py new file mode 100644 index 00000000000..f72d079683b --- /dev/null +++ b/backend/tests/integration/common_utils/managers/connector.py @@ -0,0 +1,124 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.connectors.models import InputType +from danswer.server.documents.models import ConnectorUpdateRequest +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestConnector +from tests.integration.common_utils.test_models import TestUser + + +class ConnectorManager: + @staticmethod + def create( + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + input_type: InputType = InputType.LOAD_STATE, + connector_specific_config: dict[str, Any] | None = None, + is_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestConnector: + name = f"{name}-connector" if name else f"test-connector-{uuid4()}" + + connector_update_request = ConnectorUpdateRequest( + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + is_public=is_public, + groups=groups or [], + ) + + response = requests.post( + url=f"{API_SERVER_URL}/manage/admin/connector", + json=connector_update_request.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + response_data = response.json() + return TestConnector( + id=response_data.get("id"), + name=name, + source=source, + input_type=input_type, + connector_specific_config=connector_specific_config or {}, + groups=groups, + is_public=is_public, + ) + + @staticmethod + def edit( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.patch( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + json=connector.model_dump(exclude={"id"}), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + connector: TestConnector, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + url=f"{API_SERVER_URL}/manage/admin/connector/{connector.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestConnector]: + response = requests.get( + url=f"{API_SERVER_URL}/manage/connector", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) + for conn in response.json() + ] + + @staticmethod + def get( + connector_id: int, user_performing_action: TestUser | None = None + ) -> TestConnector: + response = requests.get( + url=f"{API_SERVER_URL}/manage/connector/{connector_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + conn = response.json() + return TestConnector( + id=conn.get("id"), + name=conn.get("name", ""), + source=conn.get("source", DocumentSource.FILE), + input_type=conn.get("input_type", InputType.LOAD_STATE), + connector_specific_config=conn.get("connector_specific_config", {}), + ) diff --git a/backend/tests/integration/common_utils/managers/credential.py b/backend/tests/integration/common_utils/managers/credential.py new file mode 100644 index 00000000000..c05cd1b5a3e --- /dev/null +++ b/backend/tests/integration/common_utils/managers/credential.py @@ -0,0 +1,129 @@ +from typing import Any +from uuid import uuid4 + +import requests + +from danswer.server.documents.models import CredentialSnapshot +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestCredential +from tests.integration.common_utils.test_models import TestUser + + +class CredentialManager: + @staticmethod + def create( + credential_json: dict[str, Any] | None = None, + admin_public: bool = True, + name: str | None = None, + source: DocumentSource = DocumentSource.FILE, + curator_public: bool = True, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestCredential: + name = f"{name}-credential" if name else f"test-credential-{uuid4()}" + + credential_request = { + "name": name, + "credential_json": credential_json or {}, + "admin_public": admin_public, + "source": source, + "curator_public": curator_public, + "groups": groups or [], + } + response = requests.post( + url=f"{API_SERVER_URL}/manage/credential", + json=credential_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + + response.raise_for_status() + return TestCredential( + id=response.json()["id"], + name=name, + credential_json=credential_json or {}, + admin_public=admin_public, + source=source, + curator_public=curator_public, + groups=groups or [], + ) + + @staticmethod + def edit( + credential: TestCredential, + user_performing_action: TestUser | None = None, + ) -> None: + request = credential.model_dump(include={"name", "credential_json"}) + response = requests.put( + url=f"{API_SERVER_URL}/manage/admin/credential/{credential.id}", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def delete( + credential: TestCredential, + user_performing_action: TestUser | None = None, + ) -> None: + response = requests.delete( + url=f"{API_SERVER_URL}/manage/credential/{credential.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get( + credential_id: int, user_performing_action: TestUser | None = None + ) -> CredentialSnapshot: + response = requests.get( + url=f"{API_SERVER_URL}/manage/credential/{credential_id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return CredentialSnapshot(**response.json()) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[CredentialSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/manage/credential", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [CredentialSnapshot(**cred) for cred in response.json()] + + @staticmethod + def verify( + credential: TestCredential, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_credentials = CredentialManager.get_all(user_performing_action) + for fetched_credential in all_credentials: + if credential.id == fetched_credential.id: + if verify_deleted: + raise ValueError( + f"Credential {credential.id} found but should be deleted" + ) + if ( + credential.name == fetched_credential.name + and credential.admin_public == fetched_credential.admin_public + and credential.source == fetched_credential.source + and credential.curator_public == fetched_credential.curator_public + ): + return + if not verify_deleted: + raise ValueError(f"Credential {credential.id} not found") diff --git a/backend/tests/integration/common_utils/managers/document.py b/backend/tests/integration/common_utils/managers/document.py new file mode 100644 index 00000000000..3f691eca8f9 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document.py @@ -0,0 +1,153 @@ +from uuid import uuid4 + +import requests + +from danswer.configs.constants import DocumentSource +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import TestAPIKey +from tests.integration.common_utils.managers.cc_pair import TestCCPair +from tests.integration.common_utils.test_models import SimpleTestDocument +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.vespa import TestVespaClient + + +def _verify_document_permissions( + retrieved_doc: dict, + cc_pair: TestCCPair, + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, +) -> None: + acl_keys = set(retrieved_doc["access_control_list"].keys()) + print(f"ACL keys: {acl_keys}") + if cc_pair.is_public: + if "PUBLIC" not in acl_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} is public but" + " does not have the PUBLIC ACL key" + ) + + if doc_creating_user is not None: + if f"user_id:{doc_creating_user.id}" not in acl_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} was created by user" + f" {doc_creating_user.id} but does not have the user_id:{doc_creating_user.id} ACL key" + ) + + if group_names is not None: + expected_group_keys = {f"group:{group_name}" for group_name in group_names} + found_group_keys = {key for key in acl_keys if key.startswith("group:")} + if found_group_keys != expected_group_keys: + raise ValueError( + f"Document {retrieved_doc['document_id']} has incorrect group ACL keys. Found: {found_group_keys}, \n" + f"Expected: {expected_group_keys}" + ) + + if doc_set_names is not None: + found_doc_set_names = set(retrieved_doc.get("document_sets", {}).keys()) + if found_doc_set_names != set(doc_set_names): + raise ValueError( + f"Document set names mismatch. \nFound: {found_doc_set_names}, \n" + f"Expected: {set(doc_set_names)}" + ) + + +def _generate_dummy_document(document_id: str, cc_pair_id: int) -> dict: + return { + "document": { + "id": document_id, + "sections": [ + { + "text": f"This is test document {document_id}", + "link": f"{document_id}", + } + ], + "source": DocumentSource.NOT_APPLICABLE, + # just for testing metadata + "metadata": {"document_id": document_id}, + "semantic_identifier": f"Test Document {document_id}", + "from_ingestion_api": True, + }, + "cc_pair_id": cc_pair_id, + } + + +class DocumentManager: + @staticmethod + def seed_and_attach_docs( + cc_pair: TestCCPair, + num_docs: int = NUM_DOCS, + document_ids: list[str] | None = None, + api_key: TestAPIKey | None = None, + ) -> TestCCPair: + # Use provided document_ids if available, otherwise generate random UUIDs + if document_ids is None: + document_ids = [f"test-doc-{uuid4()}" for _ in range(num_docs)] + else: + num_docs = len(document_ids) + # Create and ingest some documents + documents: list[dict] = [] + for document_id in document_ids: + document = _generate_dummy_document(document_id, cc_pair.id) + documents.append(document) + response = requests.post( + f"{API_SERVER_URL}/danswer-api/ingestion", + json=document, + headers=api_key.headers if api_key else GENERAL_HEADERS, + ) + response.raise_for_status() + + print("Seeding completed successfully.") + cc_pair.documents = [ + SimpleTestDocument( + id=document["document"]["id"], + content=document["document"]["sections"][0]["text"], + ) + for document in documents + ] + return cc_pair + + @staticmethod + def verify( + vespa_client: TestVespaClient, + cc_pair: TestCCPair, + # If None, will not check doc sets or groups + # If empty list, will check for empty doc sets or groups + doc_set_names: list[str] | None = None, + group_names: list[str] | None = None, + doc_creating_user: TestUser | None = None, + verify_deleted: bool = False, + ) -> None: + doc_ids = [document.id for document in cc_pair.documents] + retrieved_docs_dict = vespa_client.get_documents_by_id(doc_ids)["documents"] + retrieved_docs = { + doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict + } + # Left this here for debugging purposes. + # import json + # for doc in retrieved_docs.values(): + # printable_doc = doc.copy() + # print(printable_doc.keys()) + # printable_doc.pop("embeddings") + # printable_doc.pop("title_embedding") + # print(json.dumps(printable_doc, indent=2)) + + for document in cc_pair.documents: + retrieved_doc = retrieved_docs.get(document.id) + if not retrieved_doc: + if not verify_deleted: + raise ValueError(f"Document not found: {document.id}") + continue + if verify_deleted: + raise ValueError( + f"Document found when it should be deleted: {document.id}" + ) + _verify_document_permissions( + retrieved_doc, + cc_pair, + doc_set_names, + group_names, + doc_creating_user, + ) diff --git a/backend/tests/integration/common_utils/managers/document_set.py b/backend/tests/integration/common_utils/managers/document_set.py new file mode 100644 index 00000000000..8133ccc8712 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/document_set.py @@ -0,0 +1,171 @@ +import time +from uuid import uuid4 + +import requests + +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestDocumentSet +from tests.integration.common_utils.test_models import TestUser + + +class DocumentSetManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + cc_pair_ids: list[int] | None = None, + is_public: bool = True, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestDocumentSet: + if name is None: + name = f"test_doc_set_{str(uuid4())}" + + doc_set_creation_request = { + "name": name, + "description": description or name, + "cc_pair_ids": cc_pair_ids or [], + "is_public": is_public, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + return TestDocumentSet( + id=int(response.json()), + name=name, + description=description or name, + cc_pair_ids=cc_pair_ids or [], + is_public=is_public, + is_up_to_date=True, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + doc_set_update_request = { + "id": document_set.id, + "description": document_set.description, + "cc_pair_ids": document_set.cc_pair_ids, + "is_public": document_set.is_public, + "users": document_set.users, + "groups": document_set.groups, + } + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/document-set", + json=doc_set_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def delete( + document_set: TestDocumentSet, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/manage/admin/document-set/{document_set.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return True + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[TestDocumentSet]: + response = requests.get( + f"{API_SERVER_URL}/manage/document-set", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [ + TestDocumentSet( + id=doc_set["id"], + name=doc_set["name"], + description=doc_set["description"], + cc_pair_ids=[ + cc_pair["id"] for cc_pair in doc_set["cc_pair_descriptors"] + ], + is_public=doc_set["is_public"], + is_up_to_date=doc_set["is_up_to_date"], + users=doc_set["users"], + groups=doc_set["groups"], + ) + for doc_set in response.json() + ] + + @staticmethod + def wait_for_sync( + document_sets_to_check: list[TestDocumentSet] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + # wait for document sets to be synced + start = time.time() + while True: + doc_sets = DocumentSetManager.get_all(user_performing_action) + if document_sets_to_check: + check_ids = {doc_set.id for doc_set in document_sets_to_check} + doc_set_ids = {doc_set.id for doc_set in doc_sets} + if not check_ids.issubset(doc_set_ids): + raise RuntimeError("Document set not found") + doc_sets = [doc_set for doc_set in doc_sets if doc_set.id in check_ids] + all_up_to_date = all(doc_set.is_up_to_date for doc_set in doc_sets) + + if all_up_to_date: + break + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"Document sets were not synced within the {MAX_DELAY} seconds" + ) + else: + print("Document sets were not synced yet, waiting...") + + time.sleep(2) + + @staticmethod + def verify( + document_set: TestDocumentSet, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + doc_sets = DocumentSetManager.get_all(user_performing_action) + for doc_set in doc_sets: + if doc_set.id == document_set.id: + if verify_deleted: + raise ValueError( + f"Document set {document_set.id} found but should have been deleted" + ) + if ( + doc_set.name == document_set.name + and set(doc_set.cc_pair_ids) == set(document_set.cc_pair_ids) + and doc_set.is_public == document_set.is_public + and set(doc_set.users) == set(document_set.users) + and set(doc_set.groups) == set(document_set.groups) + ): + return + if not verify_deleted: + raise ValueError(f"Document set {document_set.id} not found") diff --git a/backend/tests/integration/common_utils/managers/persona.py b/backend/tests/integration/common_utils/managers/persona.py new file mode 100644 index 00000000000..41ff43edb6f --- /dev/null +++ b/backend/tests/integration/common_utils/managers/persona.py @@ -0,0 +1,206 @@ +from uuid import uuid4 + +import requests + +from danswer.search.enums import RecencyBiasSetting +from danswer.server.features.persona.models import PersonaSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestPersona +from tests.integration.common_utils.test_models import TestUser + + +class PersonaManager: + @staticmethod + def create( + name: str | None = None, + description: str | None = None, + num_chunks: float = 5, + llm_relevance_filter: bool = True, + is_public: bool = True, + llm_filter_extraction: bool = True, + recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + name = name or f"test-persona-{uuid4()}" + description = description or f"Description for {name}" + + persona_creation_request = { + "name": name, + "description": description, + "num_chunks": num_chunks, + "llm_relevance_filter": llm_relevance_filter, + "is_public": is_public, + "llm_filter_extraction": llm_filter_extraction, + "recency_bias": recency_bias, + "prompt_ids": prompt_ids or [], + "document_set_ids": document_set_ids or [], + "tool_ids": tool_ids or [], + "llm_model_provider_override": llm_model_provider_override, + "llm_model_version_override": llm_model_version_override, + "users": users or [], + "groups": groups or [], + } + + response = requests.post( + f"{API_SERVER_URL}/persona", + json=persona_creation_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + persona_data = response.json() + + return TestPersona( + id=persona_data["id"], + name=name, + description=description, + num_chunks=num_chunks, + llm_relevance_filter=llm_relevance_filter, + is_public=is_public, + llm_filter_extraction=llm_filter_extraction, + recency_bias=recency_bias, + prompt_ids=prompt_ids or [], + document_set_ids=document_set_ids or [], + tool_ids=tool_ids or [], + llm_model_provider_override=llm_model_provider_override, + llm_model_version_override=llm_model_version_override, + users=users or [], + groups=groups or [], + ) + + @staticmethod + def edit( + persona: TestPersona, + name: str | None = None, + description: str | None = None, + num_chunks: float | None = None, + llm_relevance_filter: bool | None = None, + is_public: bool | None = None, + llm_filter_extraction: bool | None = None, + recency_bias: RecencyBiasSetting | None = None, + prompt_ids: list[int] | None = None, + document_set_ids: list[int] | None = None, + tool_ids: list[int] | None = None, + llm_model_provider_override: str | None = None, + llm_model_version_override: str | None = None, + users: list[str] | None = None, + groups: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestPersona: + persona_update_request = { + "name": name or persona.name, + "description": description or persona.description, + "num_chunks": num_chunks or persona.num_chunks, + "llm_relevance_filter": llm_relevance_filter + or persona.llm_relevance_filter, + "is_public": is_public or persona.is_public, + "llm_filter_extraction": llm_filter_extraction + or persona.llm_filter_extraction, + "recency_bias": recency_bias or persona.recency_bias, + "prompt_ids": prompt_ids or persona.prompt_ids, + "document_set_ids": document_set_ids or persona.document_set_ids, + "tool_ids": tool_ids or persona.tool_ids, + "llm_model_provider_override": llm_model_provider_override + or persona.llm_model_provider_override, + "llm_model_version_override": llm_model_version_override + or persona.llm_model_version_override, + "users": users or persona.users, + "groups": groups or persona.groups, + } + + response = requests.patch( + f"{API_SERVER_URL}/persona/{persona.id}", + json=persona_update_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + updated_persona_data = response.json() + + return TestPersona( + id=updated_persona_data["id"], + name=updated_persona_data["name"], + description=updated_persona_data["description"], + num_chunks=updated_persona_data["num_chunks"], + llm_relevance_filter=updated_persona_data["llm_relevance_filter"], + is_public=updated_persona_data["is_public"], + llm_filter_extraction=updated_persona_data["llm_filter_extraction"], + recency_bias=updated_persona_data["recency_bias"], + prompt_ids=updated_persona_data["prompts"], + document_set_ids=updated_persona_data["document_sets"], + tool_ids=updated_persona_data["tools"], + llm_model_provider_override=updated_persona_data[ + "llm_model_provider_override" + ], + llm_model_version_override=updated_persona_data[ + "llm_model_version_override" + ], + users=[user["email"] for user in updated_persona_data["users"]], + groups=updated_persona_data["groups"], + ) + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[PersonaSnapshot]: + response = requests.get( + f"{API_SERVER_URL}/admin/persona", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [PersonaSnapshot(**persona) for persona in response.json()] + + @staticmethod + def verify( + test_persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + all_personas = PersonaManager.get_all(user_performing_action) + for persona in all_personas: + if persona.id == test_persona.id: + return ( + persona.name == test_persona.name + and persona.description == test_persona.description + and persona.num_chunks == test_persona.num_chunks + and persona.llm_relevance_filter + == test_persona.llm_relevance_filter + and persona.is_public == test_persona.is_public + and persona.llm_filter_extraction + == test_persona.llm_filter_extraction + and persona.llm_model_provider_override + == test_persona.llm_model_provider_override + and persona.llm_model_version_override + == test_persona.llm_model_version_override + and set(persona.prompts) == set(test_persona.prompt_ids) + and set(persona.document_sets) == set(test_persona.document_set_ids) + and set(persona.tools) == set(test_persona.tool_ids) + and set(user.email for user in persona.users) + == set(test_persona.users) + and set(persona.groups) == set(test_persona.groups) + ) + return False + + @staticmethod + def delete( + persona: TestPersona, + user_performing_action: TestUser | None = None, + ) -> bool: + response = requests.delete( + f"{API_SERVER_URL}/persona/{persona.id}", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + return response.ok diff --git a/backend/tests/integration/common_utils/managers/user.py b/backend/tests/integration/common_utils/managers/user.py new file mode 100644 index 00000000000..0946b8b1fca --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user.py @@ -0,0 +1,122 @@ +from copy import deepcopy +from urllib.parse import urlencode +from uuid import uuid4 + +import requests + +from danswer.db.models import UserRole +from danswer.server.manage.models import AllUsersResponse +from danswer.server.models import FullUserSnapshot +from danswer.server.models import InvitedUserSnapshot +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.test_models import TestUser + + +class UserManager: + @staticmethod + def create( + name: str | None = None, + ) -> TestUser: + if name is None: + name = f"test{str(uuid4())}" + + email = f"{name}@test.com" + password = "test" + + body = { + "email": email, + "username": email, + "password": password, + } + response = requests.post( + url=f"{API_SERVER_URL}/auth/register", + json=body, + headers=GENERAL_HEADERS, + ) + response.raise_for_status() + + test_user = TestUser( + id=response.json()["id"], + email=email, + password=password, + headers=deepcopy(GENERAL_HEADERS), + ) + print(f"Created user {test_user.email}") + + test_user.headers["Cookie"] = UserManager.login_as_user(test_user) + + return test_user + + @staticmethod + def login_as_user(test_user: TestUser) -> str: + data = urlencode( + { + "username": test_user.email, + "password": test_user.password, + } + ) + headers = test_user.headers.copy() + headers["Content-Type"] = "application/x-www-form-urlencoded" + + response = requests.post( + url=f"{API_SERVER_URL}/auth/login", + data=data, + headers=headers, + ) + response.raise_for_status() + result_cookie = next(iter(response.cookies), None) + + if not result_cookie: + raise Exception("Failed to login") + + print(f"Logged in as {test_user.email}") + return f"{result_cookie.name}={result_cookie.value}" + + @staticmethod + def verify_role(user_to_verify: TestUser, target_role: UserRole) -> bool: + response = requests.get( + url=f"{API_SERVER_URL}/me", + headers=user_to_verify.headers, + ) + response.raise_for_status() + return target_role == UserRole(response.json().get("role", "")) + + @staticmethod + def set_role( + user_to_set: TestUser, + target_role: UserRole, + user_to_perform_action: TestUser | None = None, + ) -> None: + if user_to_perform_action is None: + user_to_perform_action = user_to_set + response = requests.patch( + url=f"{API_SERVER_URL}/manage/set-user-role", + json={"user_email": user_to_set.email, "new_role": target_role.value}, + headers=user_to_perform_action.headers, + ) + response.raise_for_status() + + @staticmethod + def verify(user: TestUser, user_to_perform_action: TestUser | None = None) -> None: + if user_to_perform_action is None: + user_to_perform_action = user + response = requests.get( + url=f"{API_SERVER_URL}/manage/users", + headers=user_to_perform_action.headers + if user_to_perform_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + data = response.json() + all_users = AllUsersResponse( + accepted=[FullUserSnapshot(**user) for user in data["accepted"]], + invited=[InvitedUserSnapshot(**user) for user in data["invited"]], + accepted_pages=data["accepted_pages"], + invited_pages=data["invited_pages"], + ) + for accepted_user in all_users.accepted: + if accepted_user.email == user.email and accepted_user.id == user.id: + return + raise ValueError(f"User {user.email} not found") diff --git a/backend/tests/integration/common_utils/managers/user_group.py b/backend/tests/integration/common_utils/managers/user_group.py new file mode 100644 index 00000000000..5f5ac6b0e30 --- /dev/null +++ b/backend/tests/integration/common_utils/managers/user_group.py @@ -0,0 +1,148 @@ +import time +from uuid import uuid4 + +import requests + +from ee.danswer.server.user_group.models import UserGroup +from tests.integration.common_utils.constants import API_SERVER_URL +from tests.integration.common_utils.constants import GENERAL_HEADERS +from tests.integration.common_utils.constants import MAX_DELAY +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup + + +class UserGroupManager: + @staticmethod + def create( + name: str | None = None, + user_ids: list[str] | None = None, + cc_pair_ids: list[int] | None = None, + user_performing_action: TestUser | None = None, + ) -> TestUserGroup: + name = f"{name}-user-group" if name else f"test-user-group-{uuid4()}" + + request = { + "name": name, + "user_ids": user_ids or [], + "cc_pair_ids": cc_pair_ids or [], + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group", + json=request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + test_user_group = TestUserGroup( + id=response.json()["id"], + name=response.json()["name"], + user_ids=[user["id"] for user in response.json()["users"]], + cc_pair_ids=[cc_pair["id"] for cc_pair in response.json()["cc_pairs"]], + ) + return test_user_group + + @staticmethod + def edit( + user_group: TestUserGroup, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_group.id: + raise ValueError("User group has no ID") + response = requests.patch( + f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}", + json=user_group.model_dump(), + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def set_curator_status( + test_user_group: TestUserGroup, + user_to_set_as_curator: TestUser, + is_curator: bool = True, + user_performing_action: TestUser | None = None, + ) -> None: + if not user_to_set_as_curator.id: + raise ValueError("User has no ID") + set_curator_request = { + "user_id": user_to_set_as_curator.id, + "is_curator": is_curator, + } + response = requests.post( + f"{API_SERVER_URL}/manage/admin/user-group/{test_user_group.id}/set-curator", + json=set_curator_request, + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + + @staticmethod + def get_all( + user_performing_action: TestUser | None = None, + ) -> list[UserGroup]: + response = requests.get( + f"{API_SERVER_URL}/manage/admin/user-group", + headers=user_performing_action.headers + if user_performing_action + else GENERAL_HEADERS, + ) + response.raise_for_status() + return [UserGroup(**ug) for ug in response.json()] + + @staticmethod + def verify( + user_group: TestUserGroup, + verify_deleted: bool = False, + user_performing_action: TestUser | None = None, + ) -> None: + all_user_groups = UserGroupManager.get_all(user_performing_action) + for fetched_user_group in all_user_groups: + if user_group.id == fetched_user_group.id: + if verify_deleted: + raise ValueError( + f"User group {user_group.id} found but should be deleted" + ) + fetched_cc_ids = {cc_pair.id for cc_pair in fetched_user_group.cc_pairs} + fetched_user_ids = {user.id for user in fetched_user_group.users} + user_group_cc_ids = set(user_group.cc_pair_ids) + user_group_user_ids = set(user_group.user_ids) + if ( + fetched_cc_ids == user_group_cc_ids + and fetched_user_ids == user_group_user_ids + ): + return + if not verify_deleted: + raise ValueError(f"User group {user_group.id} not found") + + @staticmethod + def wait_for_sync( + user_groups_to_check: list[TestUserGroup] | None = None, + user_performing_action: TestUser | None = None, + ) -> None: + start = time.time() + while True: + user_groups = UserGroupManager.get_all(user_performing_action) + if user_groups_to_check: + check_ids = {user_group.id for user_group in user_groups_to_check} + user_group_ids = {user_group.id for user_group in user_groups} + if not check_ids.issubset(user_group_ids): + raise RuntimeError("Document set not found") + user_groups = [ + user_group + for user_group in user_groups + if user_group.id in check_ids + ] + if all(ug.is_up_to_date for ug in user_groups): + return + + if time.time() - start > MAX_DELAY: + raise TimeoutError( + f"User groups were not synced within the {MAX_DELAY} seconds" + ) + else: + print("User groups were not synced yet, waiting...") + time.sleep(2) diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 3815aa9f972..fd9d194d661 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -20,7 +20,6 @@ from danswer.indexing.models import IndexingSetting from danswer.main import setup_postgres from danswer.main import setup_vespa -from tests.integration.common_utils.llm import seed_default_openai_provider def _run_migrations( @@ -167,6 +166,4 @@ def reset_all() -> None: reset_postgres() print("Resetting Vespa...") reset_vespa() - print("Seeding LLM Providers...") - seed_default_openai_provider() print("Finished resetting all.") diff --git a/backend/tests/integration/common_utils/seed_documents.py b/backend/tests/integration/common_utils/seed_documents.py deleted file mode 100644 index b6720c9aebe..00000000000 --- a/backend/tests/integration/common_utils/seed_documents.py +++ /dev/null @@ -1,72 +0,0 @@ -import uuid - -import requests -from pydantic import BaseModel - -from danswer.configs.constants import DocumentSource -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import API_SERVER_URL - - -class SimpleTestDocument(BaseModel): - id: str - content: str - - -class SeedDocumentResponse(BaseModel): - cc_pair_id: int - documents: list[SimpleTestDocument] - - -class TestDocumentClient: - @staticmethod - def seed_documents( - num_docs: int = 5, cc_pair_id: int | None = None - ) -> SeedDocumentResponse: - if not cc_pair_id: - connector_details = ConnectorClient.create_connector() - cc_pair_id = connector_details.cc_pair_id - - # Create and ingest some documents - documents: list[dict] = [] - for _ in range(num_docs): - document_id = f"test-doc-{uuid.uuid4()}" - document = { - "document": { - "id": document_id, - "sections": [ - { - "text": f"This is test document {document_id}", - "link": f"{document_id}", - } - ], - "source": DocumentSource.NOT_APPLICABLE, - # just for testing metadata - "metadata": {"document_id": document_id}, - "semantic_identifier": f"Test Document {document_id}", - "from_ingestion_api": True, - }, - "cc_pair_id": cc_pair_id, - } - documents.append(document) - response = requests.post( - f"{API_SERVER_URL}/danswer-api/ingestion", - json=document, - ) - response.raise_for_status() - - print("Seeding completed successfully.") - return SeedDocumentResponse( - cc_pair_id=cc_pair_id, - documents=[ - SimpleTestDocument( - id=document["document"]["id"], - content=document["document"]["sections"][0]["text"], - ) - for document in documents - ], - ) - - -if __name__ == "__main__": - seed_documents_resp = TestDocumentClient.seed_documents() diff --git a/backend/tests/integration/common_utils/test_models.py b/backend/tests/integration/common_utils/test_models.py new file mode 100644 index 00000000000..04db0851e3d --- /dev/null +++ b/backend/tests/integration/common_utils/test_models.py @@ -0,0 +1,120 @@ +from typing import Any +from uuid import UUID + +from pydantic import BaseModel +from pydantic import Field + +from danswer.auth.schemas import UserRole +from danswer.search.enums import RecencyBiasSetting +from danswer.server.documents.models import DocumentSource +from danswer.server.documents.models import InputType + +""" +These data models are used to represent the data on the testing side of things. +This means the flow is: +1. Make request that changes data in db +2. Make a change to the testing model +3. Retrieve data from db +4. Compare db data with testing model to verify +""" + + +class TestAPIKey(BaseModel): + api_key_id: int + api_key_display: str + api_key: str | None = None # only present on initial creation + api_key_name: str | None = None + api_key_role: UserRole + + user_id: UUID + headers: dict + + +class TestUser(BaseModel): + id: str + email: str + password: str + headers: dict + + +class TestCredential(BaseModel): + id: int + name: str + credential_json: dict[str, Any] + admin_public: bool + source: DocumentSource + curator_public: bool + groups: list[int] + + +class TestConnector(BaseModel): + id: int + name: str + source: DocumentSource + input_type: InputType + connector_specific_config: dict[str, Any] + groups: list[int] | None = None + is_public: bool | None = None + + +class SimpleTestDocument(BaseModel): + id: str + content: str + + +class TestCCPair(BaseModel): + id: int + name: str + connector_id: int + credential_id: int + is_public: bool + groups: list[int] + documents: list[SimpleTestDocument] = Field(default_factory=list) + + +class TestUserGroup(BaseModel): + id: int + name: str + user_ids: list[str] + cc_pair_ids: list[int] + + +class TestLLMProvider(BaseModel): + id: int + name: str + provider: str + api_key: str + default_model_name: str + is_public: bool + groups: list[TestUserGroup] + api_base: str | None = None + api_version: str | None = None + + +class TestDocumentSet(BaseModel): + id: int + name: str + description: str + cc_pair_ids: list[int] = Field(default_factory=list) + is_public: bool + is_up_to_date: bool + users: list[str] = Field(default_factory=list) + groups: list[int] = Field(default_factory=list) + + +class TestPersona(BaseModel): + id: int + name: str + description: str + num_chunks: float + llm_relevance_filter: bool + is_public: bool + llm_filter_extraction: bool + recency_bias: RecencyBiasSetting + prompt_ids: list[int] + document_set_ids: list[int] + tool_ids: list[int] + llm_model_provider_override: str | None + llm_model_version_override: str | None + users: list[str] + groups: list[int] diff --git a/backend/tests/integration/common_utils/user_groups.py b/backend/tests/integration/common_utils/user_groups.py deleted file mode 100644 index 0cd44066463..00000000000 --- a/backend/tests/integration/common_utils/user_groups.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import cast - -import requests - -from ee.danswer.server.user_group.models import UserGroup -from ee.danswer.server.user_group.models import UserGroupCreate -from tests.integration.common_utils.constants import API_SERVER_URL - - -class UserGroupClient: - @staticmethod - def create_user_group(user_group_creation_request: UserGroupCreate) -> int: - response = requests.post( - f"{API_SERVER_URL}/manage/admin/user-group", - json=user_group_creation_request.model_dump(), - ) - response.raise_for_status() - return cast(int, response.json()["id"]) - - @staticmethod - def fetch_user_groups() -> list[UserGroup]: - response = requests.get(f"{API_SERVER_URL}/manage/admin/user-group") - response.raise_for_status() - return [UserGroup(**ug) for ug in response.json()] diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py index 6c46e9f875e..314b78ad36f 100644 --- a/backend/tests/integration/conftest.py +++ b/backend/tests/integration/conftest.py @@ -1,3 +1,4 @@ +import os from collections.abc import Generator import pytest @@ -9,6 +10,25 @@ from tests.integration.common_utils.vespa import TestVespaClient +def load_env_vars(env_file: str = ".env") -> None: + current_dir = os.path.dirname(os.path.abspath(__file__)) + env_path = os.path.join(current_dir, env_file) + try: + with open(env_path, "r") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#"): + key, value = line.split("=", 1) + os.environ[key] = value.strip() + print("Successfully loaded environment variables") + except FileNotFoundError: + print(f"File {env_file} not found") + + +# Load environment variables at the module level +load_env_vars() + + @pytest.fixture def db_session() -> Generator[Session, None, None]: with get_session_context_manager() as session: diff --git a/backend/tests/integration/tests/connector/test_deletion.py b/backend/tests/integration/tests/connector/test_deletion.py index 78ad2378af9..a7708a02418 100644 --- a/backend/tests/integration/tests/connector/test_deletion.py +++ b/backend/tests/integration/tests/connector/test_deletion.py @@ -1,190 +1,305 @@ -import time - -from danswer.db.enums import ConnectorCredentialPairStatus -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.connectors import ConnectorClient -from tests.integration.common_utils.constants import MAX_DELAY -from tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient -from tests.integration.common_utils.user_groups import UserGroupClient -from tests.integration.common_utils.user_groups import UserGroupCreate +""" +This file contains tests for the following: +- Ensuring deletion of a connector also: + - deletes the documents in vespa for that connector + - updates the document sets and user groups to remove the connector +- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected +""" +from uuid import uuid4 + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser +from tests.integration.common_utils.test_models import TestUserGroup from tests.integration.common_utils.vespa import TestVespaClient def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c2_details = ConnectorClient.create_connector(name_prefix="tc2") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, ) - c2_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c2_details.cc_pair_id + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + num_docs=NUM_DOCS, + api_key=api_key, ) # create document sets - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="Intially connector to be deleted, should be empty after test", - cc_pair_ids=[c1_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) - ) - - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Intially both connectors, should contain undeleted connector after test", - cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + name="Test Document Set 2", + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, ) # wait for document sets to be synced - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None - ) + DocumentSetManager.wait_for_sync(user_performing_action=admin_user) - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") + print("Document sets created and synced") - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - break + # create user groups + user_group_1: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + user_group_2: TestUserGroup = UserGroupManager.create( + cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync(user_performing_action=admin_user) - if time.time() - start > MAX_DELAY: - raise TimeoutError("Document sets were not synced within the max delay") + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) - time.sleep(2) + # Update local records to match the database for later comparison + user_group_1.cc_pair_ids = [] + user_group_2.cc_pair_ids = [cc_pair_2.id] + doc_set_1.cc_pair_ids = [] + doc_set_2.cc_pair_ids = [cc_pair_2.id] + cc_pair_1.groups = [] + cc_pair_2.groups = [user_group_2.id] - print("Document sets created and synced") + CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) - # if so, create ACLs - user_group_1 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 1", user_ids=[], cc_pair_ids=[c1_details.cc_pair_id] - ) + # validate vespa documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + verify_deleted=True, ) - user_group_2 = UserGroupClient.create_user_group( - UserGroupCreate( - name="Test User Group 2", - user_ids=[], - cc_pair_ids=[c1_details.cc_pair_id, c2_details.cc_pair_id], - ) + + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[doc_set_2.name], + group_names=[user_group_2.name], + doc_creating_user=admin_user, + verify_deleted=False, ) - # wait for user groups to be available - start = time.time() - while True: - user_groups = {ug.id: ug for ug in UserGroupClient.fetch_user_groups()} + # check that only connector 1 is deleted + CCPairManager.verify( + cc_pair=cc_pair_2, + user_performing_action=admin_user, + ) - if not ( - user_group_1 in user_groups.keys() and user_group_2 in user_groups.keys() - ): - raise RuntimeError("User groups not found") + # validate document sets + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) - if ( - user_groups[user_group_1].is_up_to_date - and user_groups[user_group_2].is_up_to_date - ): - break + # validate user groups + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.verify( + user_group=user_group_2, + user_performing_action=admin_user, + ) + + +def test_connector_deletion_for_overlapping_connectors( + reset: None, vespa_client: TestVespaClient +) -> None: + """Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping + document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors. + """ + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) - if time.time() - start > MAX_DELAY: - raise TimeoutError("User groups were not synced within the max delay") + # create connectors + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + cc_pair_2 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) - time.sleep(2) + doc_ids = [str(uuid4())] + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + document_ids=doc_ids, + api_key=api_key, + ) + cc_pair_2 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_2, + document_ids=doc_ids, + api_key=api_key, + ) - print("User groups created and synced") + # verify vespa document exists and that it is not in any document sets or groups + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_set_names=[], + group_names=[], + doc_creating_user=admin_user, + ) - # delete connector 1 - ConnectorClient.update_connector_status( - cc_pair_id=c1_details.cc_pair_id, status=ConnectorCredentialPairStatus.PAUSED + # create document set + doc_set_1 = DocumentSetManager.create( + name="Test Document Set 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, ) - ConnectorClient.delete_connector( - connector_id=c1_details.connector_id, credential_id=c1_details.credential_id + DocumentSetManager.wait_for_sync( + document_sets_to_check=[doc_set_1], + user_performing_action=admin_user, ) - start = time.time() - while True: - connectors = ConnectorClient.get_connectors() + print("Document set 1 created and synced") - if c1_details.connector_id not in [c["id"] for c in connectors]: - break + # verify vespa document is in the document set + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + doc_creating_user=admin_user, + ) - if time.time() - start > MAX_DELAY: - raise TimeoutError("Connector 1 was not deleted within the max delay") + # create a user group and attach it to connector 1 + user_group_1: TestUserGroup = UserGroupManager.create( + name="Test User Group 1", + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], + user_performing_action=admin_user, + ) + cc_pair_1.groups = [user_group_1.id] - time.sleep(2) + print("User group 1 created and synced") - print("Connector 1 deleted") + # create a user group and attach it to connector 2 + user_group_2: TestUserGroup = UserGroupManager.create( + name="Test User Group 2", + cc_pair_ids=[cc_pair_2.id], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_2], + user_performing_action=admin_user, + ) + cc_pair_2.groups = [user_group_2.id] - # validate vespa documents - c1_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c1_seed_res.documents] - )["documents"] - c2_vespa_docs = vespa_client.get_documents_by_id( - [doc.id for doc in c2_seed_res.documents] - )["documents"] - - assert len(c1_vespa_docs) == 0 - assert len(c2_vespa_docs) == 5 - - for doc in c2_vespa_docs: - assert doc["fields"]["access_control_list"] == { - "PUBLIC": 1, - "group:Test User Group 2": 1, - } - assert doc["fields"]["document_sets"] == {"Test Document Set 2": 1} + print("User group 2 created and synced") - # check that only connector 1 is deleted - # TODO: check for the CC pair rather than the connector once the refactor is done - all_connectors = ConnectorClient.get_connectors() - assert len(all_connectors) == 1 - assert all_connectors[0]["id"] == c2_details.connector_id + # verify vespa document is in the user group + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_2, + group_names=[user_group_1.name, user_group_2.name], + doc_creating_user=admin_user, + ) - # validate document sets - all_doc_sets = DocumentSetClient.fetch_document_sets() - assert len(all_doc_sets) == 2 + # delete connector 1 + CCPairManager.pause_cc_pair( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) + CCPairManager.delete( + cc_pair=cc_pair_1, + user_performing_action=admin_user, + ) - doc_set_1_found = False - doc_set_2_found = False - for doc_set in all_doc_sets: - if doc_set.id == doc_set_1_id: - doc_set_1_found = True - assert doc_set.cc_pair_descriptors == [] + # EVERYTHING BELOW HERE IS CURRENTLY BROKEN AND NEEDS TO BE FIXED SERVER SIDE - if doc_set.id == doc_set_2_id: - doc_set_2_found = True - assert len(doc_set.cc_pair_descriptors) == 1 - assert doc_set.cc_pair_descriptors[0].id == c2_details.cc_pair_id + # wait for deletion to finish + # CCPairManager.wait_for_deletion_completion(user_performing_action=admin_user) - assert doc_set_1_found - assert doc_set_2_found + # print("Connector 1 deleted") - # validate user groups - all_user_groups = UserGroupClient.fetch_user_groups() - assert len(all_user_groups) == 2 - - user_group_1_found = False - user_group_2_found = False - for user_group in all_user_groups: - if user_group.id == user_group_1: - user_group_1_found = True - assert user_group.cc_pairs == [] - if user_group.id == user_group_2: - user_group_2_found = True - assert len(user_group.cc_pairs) == 1 - assert user_group.cc_pairs[0].id == c2_details.cc_pair_id - - assert user_group_1_found - assert user_group_2_found + # check that only connector 1 is deleted + # TODO: check for the CC pair rather than the connector once the refactor is done + # CCPairManager.verify( + # cc_pair=cc_pair_1, + # verify_deleted=True, + # user_performing_action=admin_user, + # ) + # CCPairManager.verify( + # cc_pair=cc_pair_2, + # user_performing_action=admin_user, + # ) + + # verify the document is not in any document sets + # verify the document is only in user group 2 + # DocumentManager.verify( + # vespa_client=vespa_client, + # cc_pair=cc_pair_2, + # doc_set_names=[], + # group_names=[user_group_2.name], + # doc_creating_user=admin_user, + # verify_deleted=False, + # ) diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py index b00c2e3d1e6..981a9cbd026 100644 --- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py +++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py @@ -1,34 +1,59 @@ import requests -from tests.integration.common_utils.connectors import ConnectorClient +from danswer.configs.constants import MessageType from tests.integration.common_utils.constants import API_SERVER_URL -from tests.integration.common_utils.seed_documents import TestDocumentClient +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.llm import LLMProviderManager +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestCCPair +from tests.integration.common_utils.test_models import TestUser def test_send_message_simple_with_history(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + # create connectors - c1_details = ConnectorClient.create_connector(name_prefix="tc1") - c1_seed_res = TestDocumentClient.seed_documents( - num_docs=5, cc_pair_id=c1_details.cc_pair_id + cc_pair_1: TestCCPair = CCPairManager.create_from_scratch( + user_performing_action=admin_user, + ) + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, + ) + LLMProviderManager.create(user_performing_action=admin_user) + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, ) response = requests.post( f"{API_SERVER_URL}/chat/send-message-simple-with-history", json={ - "messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}], + "messages": [ + { + "message": cc_pair_1.documents[0].content, + "role": MessageType.USER.value, + } + ], "persona_id": 0, "prompt_id": 0, }, + headers=admin_user.headers, ) assert response.status_code == 200 response_json = response.json() # Check that the top document is the correct document - assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id + assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id # assert that the metadata is correct - for doc in c1_seed_res.documents: + for doc in cc_pair_1.documents: found_doc = next( (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None ) diff --git a/backend/tests/integration/tests/document_set/test_syncing.py b/backend/tests/integration/tests/document_set/test_syncing.py index 9a6b42ab5df..ab31b751471 100644 --- a/backend/tests/integration/tests/document_set/test_syncing.py +++ b/backend/tests/integration/tests/document_set/test_syncing.py @@ -1,78 +1,66 @@ -import time - -from danswer.server.features.document_set.models import DocumentSetCreationRequest -from tests.integration.common_utils.document_sets import DocumentSetClient -from tests.integration.common_utils.seed_documents import TestDocumentClient +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.constants import NUM_DOCS +from tests.integration.common_utils.managers.api_key import APIKeyManager +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document import DocumentManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.test_models import TestAPIKey +from tests.integration.common_utils.test_models import TestUser from tests.integration.common_utils.vespa import TestVespaClient def test_multiple_document_sets_syncing_same_connnector( reset: None, vespa_client: TestVespaClient ) -> None: - # Seed documents - seed_result = TestDocumentClient.seed_documents(num_docs=5) - cc_pair_id = seed_result.cc_pair_id + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") - # Create first document set - doc_set_1_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 1", - description="First test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # add api key to user + api_key: TestAPIKey = APIKeyManager.create( + user_performing_action=admin_user, ) - doc_set_2_id = DocumentSetClient.create_document_set( - DocumentSetCreationRequest( - name="Test Document Set 2", - description="Second test document set", - cc_pair_ids=[cc_pair_id], - is_public=True, - users=[], - groups=[], - ) + # create connector + cc_pair_1 = CCPairManager.create_from_scratch( + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, ) - # wait for syncing to be complete - max_delay = 45 - start = time.time() - while True: - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_1 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_1_id), None - ) - doc_set_2 = next( - (doc_set for doc_set in doc_sets if doc_set.id == doc_set_2_id), None - ) - - if not doc_set_1 or not doc_set_2: - raise RuntimeError("Document set not found") - - if doc_set_1.is_up_to_date and doc_set_2.is_up_to_date: - assert [ccp.id for ccp in doc_set_1.cc_pair_descriptors] == [ - ccp.id for ccp in doc_set_2.cc_pair_descriptors - ] - break + # seed documents + cc_pair_1 = DocumentManager.seed_and_attach_docs( + cc_pair=cc_pair_1, + num_docs=NUM_DOCS, + api_key=api_key, + ) - if time.time() - start > max_delay: - raise TimeoutError("Document sets were not synced within the max delay") + # Create document sets + doc_set_1 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) + doc_set_2 = DocumentSetManager.create( + cc_pair_ids=[cc_pair_1.id], + user_performing_action=admin_user, + ) - time.sleep(2) + DocumentSetManager.wait_for_sync( + user_performing_action=admin_user, + ) - # get names so we can compare to what is in vespa - doc_sets = DocumentSetClient.fetch_document_sets() - doc_set_names = {doc_set.name for doc_set in doc_sets} + DocumentSetManager.verify( + document_set=doc_set_1, + user_performing_action=admin_user, + ) + DocumentSetManager.verify( + document_set=doc_set_2, + user_performing_action=admin_user, + ) # make sure documents are as expected - seeded_document_ids = [doc.id for doc in seed_result.documents] - - result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents]) - documents = result["documents"] - assert len(documents) == len(seed_result.documents) - assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents) - assert all( - set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents + DocumentManager.verify( + vespa_client=vespa_client, + cc_pair=cc_pair_1, + doc_set_names=[doc_set_1.name, doc_set_2.name], + doc_creating_user=admin_user, ) diff --git a/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py new file mode 100644 index 00000000000..c52c5826eae --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_cc_pair_permissions.py @@ -0,0 +1,179 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connector-credential pairs. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_cc_pair_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Create a credentials that the curator is and is not curator of + connector_1 = ConnectorManager.create( + name="curator_owned_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=admin_user, + ) + # currently we dont enforce permissions at the connector level + # pending cc_pair -> connector rework + # connector_2 = ConnectorManager.create( + # name="curator_visible_connector", + # source=DocumentSource.CONFLUENCE, + # groups=[user_group_2.id], + # is_public=False, + # user_performing_action=admin_user, + # ) + credential_1 = CredentialManager.create( + name="curator_owned_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=admin_user, + ) + credential_2 = CredentialManager.create( + name="curator_visible_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + curator_public=False, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public cc pair + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_1", + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair for a user group they are not a curator of + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc + # pair without an attached user group + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_1.id, + name="invalid_cc_pair_2", + groups=[], + is_public=False, + user_performing_action=curator, + ) + + # # This test is currently disabled because permissions are + # # not enforced at the connector level + # # Curators should not be able to create a cc pair + # # for a user group that the connector does not belong to (NOT WORKING) + # with pytest.raises(HTTPError): + # CCPairManager.create( + # connector_id=connector_2.id, + # credential_id=credential_1.id, + # name="invalid_cc_pair_3", + # groups=[user_group_1.id], + # is_public=False, + # user_performing_action=curator, + # ) + + # Curators should not be able to create a cc + # pair for a user group that the credential does not belong to + with pytest.raises(HTTPError): + CCPairManager.create( + connector_id=connector_1.id, + credential_id=credential_2.id, + name="invalid_cc_pair_4", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # cc pair for a user group they are a curator of + valid_cc_pair = CCPairManager.create( + name="valid_cc_pair", + connector_id=connector_1.id, + credential_id=credential_1.id, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + # Verify the created cc pair + CCPairManager.verify( + cc_pair=valid_cc_pair, + user_performing_action=curator, + ) + + # Test pausing the cc pair + CCPairManager.pause_cc_pair(valid_cc_pair, user_performing_action=curator) + + # Test deleting the cc pair + CCPairManager.delete(valid_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + CCPairManager.verify( + cc_pair=valid_cc_pair, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_connector_permissions.py b/backend/tests/integration/tests/permissions/test_connector_permissions.py new file mode 100644 index 00000000000..279c0568bfb --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_connector_permissions.py @@ -0,0 +1,136 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating connectors. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_connector_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a cc pair for a + # user group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + + # Curators should be able to create a private + # connector for a user group they are a curator of + valid_connector = ConnectorManager.create( + name="valid_connector", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + assert valid_connector.id is not None + + # Verify the created connector + created_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert created_connector.name == valid_connector.name + assert created_connector.source == valid_connector.source + + # Verify that the connector can be found in the list of all connectors + all_connectors = ConnectorManager.get_all(user_performing_action=curator) + assert any(conn.id == valid_connector.id for conn in all_connectors) + + # Test editing the connector + valid_connector.name = "updated_valid_connector" + ConnectorManager.edit(valid_connector, user_performing_action=curator) + + # Verify the edit + updated_connector = ConnectorManager.get( + valid_connector.id, user_performing_action=curator + ) + assert updated_connector.name == "updated_valid_connector" + + # Test deleting the connector + ConnectorManager.delete(connector=valid_connector, user_performing_action=curator) + + # Verify the deletion + all_connectors_after_delete = ConnectorManager.get_all( + user_performing_action=curator + ) + assert all(conn.id != valid_connector.id for conn in all_connectors_after_delete) + + # Test that curator cannot create a connector for a group they are not a curator of + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_3", + source=DocumentSource.CONFLUENCE, + groups=[user_group_2.id], + is_public=False, + user_performing_action=curator, + ) + + # Test that curator cannot create a public connector + with pytest.raises(HTTPError): + ConnectorManager.create( + name="invalid_connector_4", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + is_public=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_credential_permissions.py b/backend/tests/integration/tests/permissions/test_credential_permissions.py new file mode 100644 index 00000000000..1311f1a3d2d --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_credential_permissions.py @@ -0,0 +1,108 @@ +""" +This file takes the happy path to adding a curator to a user group and then tests +the permissions of the curator manipulating credentials. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_credential_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # setting the user as a curator for the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating another user group that the user is not a curator of + user_group_2 = UserGroupManager.create( + name="user_group_2", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # END OF HAPPY PATH + + """Tests for things Curators should not be able to do""" + + # Curators should not be able to create a public credential + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_1", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=True, + user_performing_action=curator, + ) + + # Curators should not be able to create a credential for a user group they are not a curator of + with pytest.raises(HTTPError): + CredentialManager.create( + name="invalid_credential_2", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id, user_group_2.id], + curator_public=False, + user_performing_action=curator, + ) + + """Tests for things Curators should be able to do""" + # Curators should be able to create a private credential for a user group they are a curator of + valid_credential = CredentialManager.create( + name="valid_credential", + source=DocumentSource.CONFLUENCE, + groups=[user_group_1.id], + curator_public=False, + user_performing_action=curator, + ) + + # Verify the created credential + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test editing the credential + valid_credential.name = "updated_valid_credential" + CredentialManager.edit(valid_credential, user_performing_action=curator) + + # Verify the edit + CredentialManager.verify( + credential=valid_credential, + user_performing_action=curator, + ) + + # Test deleting the credential + CredentialManager.delete(valid_credential, user_performing_action=curator) + + # Verify the deletion + CredentialManager.verify( + credential=valid_credential, + verify_deleted=True, + user_performing_action=curator, + ) diff --git a/backend/tests/integration/tests/permissions/test_doc_set_permissions.py b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py new file mode 100644 index 00000000000..a2601bf4e46 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_doc_set_permissions.py @@ -0,0 +1,186 @@ +import pytest +from requests.exceptions import HTTPError + +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.document_set import DocumentSetManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_doc_set_permissions_setup(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + + # Creating a second user (curator) + curator: TestUser = UserManager.create(name="curator") + + # Creating the first user group + user_group_1 = UserGroupManager.create( + name="curated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Setting the curator as a curator for the first user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Creating a second user group + user_group_2 = UserGroupManager.create( + name="uncurated_user_group", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # Admin creates a cc_pair + private_cc_pair = CCPairManager.create_from_scratch( + is_public=False, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # Admin creates a public cc_pair + public_cc_pair = CCPairManager.create_from_scratch( + is_public=True, + source=DocumentSource.INGESTION_API, + user_performing_action=admin_user, + ) + + # END OF HAPPY PATH + + """Tests for things Curators/Admins should not be able to do""" + + # Test that curator cannot create a document set for the group they don't curate + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 1", + groups=[user_group_2.id], + cc_pair_ids=[public_cc_pair.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set attached to both groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 2", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id, user_group_2.id], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no groups + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 3", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[], + user_performing_action=curator, + ) + + # Test that curator cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test that admin cannot create a document set with no cc_pairs + with pytest.raises(HTTPError): + DocumentSetManager.create( + name="Invalid Document Set 4", + is_public=False, + cc_pair_ids=[], + groups=[user_group_1.id], + user_performing_action=admin_user, + ) + + """Tests for things Curators should be able to do""" + # Test that curator can create a document set for the group they curate + valid_doc_set = DocumentSetManager.create( + name="Valid Document Set", + is_public=False, + cc_pair_ids=[public_cc_pair.id], + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Verify that the valid document set was created + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify that only one document set exists + all_doc_sets = DocumentSetManager.get_all(user_performing_action=admin_user) + assert len(all_doc_sets) == 1 + + # Add the private_cc_pair to the doc set on our end for later comparison + valid_doc_set.cc_pair_ids.append(private_cc_pair.id) + + # Confirm the curator can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + # Confirm the admin can't add the private_cc_pair to the doc set + with pytest.raises(HTTPError): + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Verify the document set has not been updated in the db + with pytest.raises(ValueError): + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) + + # Add the private_cc_pair to the user group on our end for later comparison + user_group_1.cc_pair_ids.append(private_cc_pair.id) + + # Admin adds the cc_pair to the group the curator curates + UserGroupManager.edit( + user_group=user_group_1, + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + UserGroupManager.verify( + user_group=user_group_1, + user_performing_action=admin_user, + ) + + # Confirm the curator can now add the cc_pair to the doc set + DocumentSetManager.edit( + document_set=valid_doc_set, + user_performing_action=curator, + ) + DocumentSetManager.wait_for_sync( + document_sets_to_check=[valid_doc_set], user_performing_action=admin_user + ) + # Verify the updated document set + DocumentSetManager.verify( + document_set=valid_doc_set, + user_performing_action=admin_user, + ) diff --git a/backend/tests/integration/tests/permissions/test_user_role_permissions.py b/backend/tests/integration/tests/permissions/test_user_role_permissions.py new file mode 100644 index 00000000000..5da91a57af8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_user_role_permissions.py @@ -0,0 +1,93 @@ +""" +This file tests the ability of different user types to set the role of other users. +""" +import pytest +from requests.exceptions import HTTPError + +from danswer.db.models import UserRole +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_user_role_setting_permissions(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a basic user + basic_user: TestUser = UserManager.create(name="basic_user") + assert UserManager.verify_role(basic_user, UserRole.BASIC) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + assert UserManager.verify_role(curator, UserRole.BASIC) + + # Creating a curator without adding to a group should not work + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=curator, + target_role=UserRole.CURATOR, + user_to_perform_action=admin_user, + ) + + global_curator: TestUser = UserManager.create(name="global_curator") + assert UserManager.verify_role(global_curator, UserRole.BASIC) + + # Setting the role of a global curator should not work for a basic user + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=basic_user, + ) + + # Setting the role of a global curator should work for an admin user + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.GLOBAL_CURATOR, + user_to_perform_action=admin_user, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Setting the role of a global curator should not work for an invalid curator + with pytest.raises(HTTPError): + UserManager.set_role( + user_to_set=global_curator, + target_role=UserRole.BASIC, + user_to_perform_action=global_curator, + ) + assert UserManager.verify_role(global_curator, UserRole.GLOBAL_CURATOR) + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should fail because the curator is not in the user group + with pytest.raises(HTTPError): + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + + # Adding the curator to the user group + user_group_1.user_ids = [curator.id] + UserGroupManager.edit(user_group=user_group_1, user_performing_action=admin_user) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + + # This should work because the curator is in the user group + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) diff --git a/backend/tests/integration/tests/permissions/test_whole_curator_flow.py b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py new file mode 100644 index 00000000000..878ba1e17e8 --- /dev/null +++ b/backend/tests/integration/tests/permissions/test_whole_curator_flow.py @@ -0,0 +1,86 @@ +""" +This test tests the happy path for curator permissions +""" +from danswer.db.models import UserRole +from danswer.server.documents.models import DocumentSource +from tests.integration.common_utils.managers.cc_pair import CCPairManager +from tests.integration.common_utils.managers.connector import ConnectorManager +from tests.integration.common_utils.managers.credential import CredentialManager +from tests.integration.common_utils.managers.user import TestUser +from tests.integration.common_utils.managers.user import UserManager +from tests.integration.common_utils.managers.user_group import UserGroupManager + + +def test_whole_curator_flow(reset: None) -> None: + # Creating an admin user (first user created is automatically an admin) + admin_user: TestUser = UserManager.create(name="admin_user") + assert UserManager.verify_role(admin_user, UserRole.ADMIN) + + # Creating a curator + curator: TestUser = UserManager.create(name="curator") + + # Creating a user group + user_group_1 = UserGroupManager.create( + name="user_group_1", + user_ids=[curator.id], + cc_pair_ids=[], + user_performing_action=admin_user, + ) + UserGroupManager.wait_for_sync( + user_groups_to_check=[user_group_1], user_performing_action=admin_user + ) + # Making curator a curator of user_group_1 + UserGroupManager.set_curator_status( + test_user_group=user_group_1, + user_to_set_as_curator=curator, + user_performing_action=admin_user, + ) + assert UserManager.verify_role(curator, UserRole.CURATOR) + + # Creating a credential as curator + test_credential = CredentialManager.create( + name="curator_test_credential", + source=DocumentSource.FILE, + curator_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Creating a connector as curator + test_connector = ConnectorManager.create( + name="curator_test_connector", + source=DocumentSource.FILE, + is_public=False, + groups=[user_group_1.id], + user_performing_action=curator, + ) + + # Test editing the connector + test_connector.name = "updated_test_connector" + ConnectorManager.edit(connector=test_connector, user_performing_action=curator) + + # Creating a CC pair as curator + test_cc_pair = CCPairManager.create( + connector_id=test_connector.id, + credential_id=test_credential.id, + name="curator_test_cc_pair", + groups=[user_group_1.id], + is_public=False, + user_performing_action=curator, + ) + + CCPairManager.verify(cc_pair=test_cc_pair, user_performing_action=admin_user) + + # Verify that the curator can pause and unpause the CC pair + CCPairManager.pause_cc_pair(cc_pair=test_cc_pair, user_performing_action=curator) + + # Verify that the curator can delete the CC pair + CCPairManager.delete(cc_pair=test_cc_pair, user_performing_action=curator) + CCPairManager.wait_for_deletion_completion(user_performing_action=curator) + + # Verify that the CC pair has been deleted + CCPairManager.verify( + cc_pair=test_cc_pair, + verify_deleted=True, + user_performing_action=admin_user, + ) diff --git a/web/src/lib/credential.ts b/web/src/lib/credential.ts index 0552e73cc9e..03f6c6e75da 100644 --- a/web/src/lib/credential.ts +++ b/web/src/lib/credential.ts @@ -73,7 +73,7 @@ export function updateCredential(credentialId: number, newDetails: any) { ([key, value]) => key !== "name" && value !== "" ) ); - return fetch(`/api/manage/admin/credentials/${credentialId}`, { + return fetch(`/api/manage/admin/credential/${credentialId}`, { method: "PUT", headers: { "Content-Type": "application/json", @@ -86,7 +86,7 @@ export function updateCredential(credentialId: number, newDetails: any) { } export function swapCredential(newCredentialId: number, connectorId: number) { - return fetch(`/api/manage/admin/credentials/swap`, { + return fetch(`/api/manage/admin/credential/swap`, { method: "PUT", headers: { "Content-Type": "application/json",