diff --git a/tests/server/api/v1/test_datasets.py b/tests/server/api/v1/test_datasets.py index 01c56ca32c..7ef97781c7 100644 --- a/tests/server/api/v1/test_datasets.py +++ b/tests/server/api/v1/test_datasets.py @@ -2408,7 +2408,7 @@ def test_create_dataset_records_with_nonexistent_dataset_id(client: TestClient, assert db.query(Response).count() == 0 -def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: +def create_dataset_for_search(user: Optional[User] = None) -> Tuple[Dataset, List[Record]]: dataset = DatasetFactory.create(status=DatasetStatus.ready) TextFieldFactory.create(name="input", dataset=dataset) TextFieldFactory.create(name="output", dataset=dataset) @@ -2425,30 +2425,36 @@ def create_dataset_for_search() -> Tuple[Dataset, List[Record]]: record=records[0], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ResponseFactory.create( record=records[1], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ResponseFactory.create( record=records[2], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ResponseFactory.create( record=records[3], values={"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}, status=ResponseStatus.submitted, + user=user, ), ] + # Add some responses from other users + ResponseFactory.create_batch(10, record=records[0], status=ResponseStatus.submitted) return dataset, records, responses def test_search_dataset_records( client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, records, _ = create_dataset_for_search() + dataset, records, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses( items=[ @@ -2509,9 +2515,9 @@ def test_search_dataset_records( def test_search_dataset_records_including_responses( - client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, records, responses = create_dataset_for_search() + dataset, records, responses = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses( items=[ @@ -2601,7 +2607,7 @@ def test_search_dataset_records_including_responses( def test_search_dataset_records_with_response_status_filter( client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, _, _ = create_dataset_for_search() + dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2622,9 +2628,9 @@ def test_search_dataset_records_with_response_status_filter( def test_search_dataset_records_with_limit( - client: TestClient, mock_search_engine: SearchEngine, admin_auth_header: dict + client: TestClient, mock_search_engine: SearchEngine, admin: User, admin_auth_header: dict ): - dataset, _, _ = create_dataset_for_search() + dataset, _, _ = create_dataset_for_search(user=admin) mock_search_engine.search.return_value = SearchResponses(items=[]) query_json = {"text": {"q": "Hello", "field": "input"}} @@ -2644,8 +2650,8 @@ def test_search_dataset_records_with_limit( assert response.status_code == 200 -def test_search_dataset_records_as_annotator(client: TestClient, mock_search_engine: SearchEngine): - dataset, records, _ = create_dataset_for_search() +def test_search_dataset_records_as_annotator(client: TestClient, admin: User, mock_search_engine: SearchEngine): + dataset, records, _ = create_dataset_for_search(user=admin) annotator = AnnotatorFactory.create(workspaces=[dataset.workspace]) mock_search_engine.search.return_value = SearchResponses(