From fb92ca2da702a4233c81bd5abc141c4f673555f5 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Thu, 7 Sep 2023 15:33:37 +0200 Subject: [PATCH] feat: add `workspace_id` param to `GET /api/v1/me/datasets` (#3727) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description This PR adds the `workspace_id` param to `GET /api/v1/me/datasets` so that the workspace filtering when listing `FeedbackTask` datasets is applied on the API side, and so that no local filtering is applied on the client, e.g. in `FeedbackDataset.list(workspace=...)`. Closes #3726 **Type of change** - [X] New feature (non-breaking change which adds functionality) - [X] Improvement (change adding some improvement to an existing functionality) **How Has This Been Tested** - [x] Add tests for `GET /api/v1/me/datasets` using the `workspace_id` param, including the updated policies for non-owner users - [x] Add tests for `list_datasets` in the Python SDK using the `workspace_id` arg - [x] Add tests for `FeedbackDataset.list` in the Python client using the `workspace` arg **Checklist** - [X] I added relevant documentation - [X] My code follows the style guidelines of this project - [X] I did a self-review of my code - [ ] I made corresponding changes to the documentation - [X] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK) (see text above) - [x] I have added relevant notes to the CHANGELOG.md file (See https://keepachangelog.com/) --------- Co-authored-by: Gabriel Martín Blázquez --- CHANGELOG.md | 2 ++ src/argilla/client/feedback/dataset/mixins.py | 10 +++--- src/argilla/client/sdk/v1/datasets/api.py | 17 +++++++--- .../server/apis/v1/handlers/datasets.py | 26 ++++++++------ src/argilla/server/policies.py | 11 ++++-- .../client/sdk/v1/test_datasets.py | 34 ++++++++++++++++++- tests/unit/server/api/v1/test_datasets.py | 31 
+++++++++++++++-- 7 files changed, 106 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ef238de13..22a4ba16e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ These are the section headers that we use: - Added `created_at` and `updated_at` properties to `RemoteFeedbackDataset` and `FilteredRemoteFeedbackDataset` ([#3709](https://github.com/argilla-io/argilla/pull/3709)). - Added handling `PermissionError` when executing a command with a logged in user with not enough permissions ([#3717](https://github.com/argilla-io/argilla/pull/3717)). - Added `workspaces add-user` command to add a user to workspace ([#3712](https://github.com/argilla-io/argilla/pull/3712)). +- Added `workspace_id` param to `GET /api/v1/me/datasets` endpoint ([#3727](https://github.com/argilla-io/argilla/pull/3727)). +- Added `workspace_id` arg to `list_datasets` in the Python SDK ([#3727](https://github.com/argilla-io/argilla/pull/3727)). ### Changed diff --git a/src/argilla/client/feedback/dataset/mixins.py b/src/argilla/client/feedback/dataset/mixins.py index 25f307c27e..8a69a0ff86 100644 --- a/src/argilla/client/feedback/dataset/mixins.py +++ b/src/argilla/client/feedback/dataset/mixins.py @@ -314,14 +314,17 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ if workspace is not None: workspace = Workspace.from_name(workspace) - # TODO(alvarobartt or gabrielmbmb): add `workspace_id` in `GET /api/v1/datasets` - # and in `GET /api/v1/me/datasets` to filter by workspace try: - datasets = datasets_api_v1.list_datasets(client=httpx_client).parsed + datasets = datasets_api_v1.list_datasets( + client=httpx_client, workspace_id=workspace.id if workspace is not None else None + ).parsed except Exception as e: raise RuntimeError( f"Failed while listing the `FeedbackDataset` datasets in Argilla with exception: {e}" ) from e + + if len(datasets) == 0: + return [] return [ RemoteFeedbackDataset( client=httpx_client, @@ -335,5 
+338,4 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[ guidelines=dataset.guidelines or None, ) for dataset in datasets - if workspace is None or dataset.workspace_id == workspace.id ] diff --git a/src/argilla/client/sdk/v1/datasets/api.py b/src/argilla/client/sdk/v1/datasets/api.py index 72fd5aa25e..e17a482ec3 100644 --- a/src/argilla/client/sdk/v1/datasets/api.py +++ b/src/argilla/client/sdk/v1/datasets/api.py @@ -132,19 +132,28 @@ def publish_dataset( def list_datasets( client: httpx.Client, -) -> Response[Union[List[FeedbackDatasetModel], ErrorMessage, HTTPValidationError]]: - """Sends a GET request to `/api/v1/datasets` endpoint to retrieve a list of `FeedbackTask` datasets. + workspace_id: Optional[UUID] = None, +) -> Response[Union[list, List[FeedbackDatasetModel], ErrorMessage, HTTPValidationError]]: + """Sends a GET request to `/api/v1/me/datasets` endpoint to retrieve a list of + `FeedbackTask` datasets filtered by `workspace_id` if applicable. Args: client: the authenticated Argilla client to be used to send the request to the API. + workspace_id: the id of the workspace to filter the datasets by. Note that the user + should either be owner or have access to the workspace. Defaults to None. Returns: A `Response` object containing a `parsed` attribute with the parsed response if the - request was successful, which is a list of `FeedbackDatasetModel`. + request was successful, which is a list of `FeedbackDatasetModel` if any, otherwise + it will contain an empty list. 
""" url = "/api/v1/me/datasets" - response = client.get(url=url) + params = {} + if workspace_id is not None: + params["workspace_id"] = str(workspace_id) + + response = client.get(url=url, params=params) if response.status_code == 200: response_obj = Response.from_httpx_response(response) diff --git a/src/argilla/server/apis/v1/handlers/datasets.py b/src/argilla/server/apis/v1/handlers/datasets.py index c6dff3818a..608ba2fb01 100644 --- a/src/argilla/server/apis/v1/handlers/datasets.py +++ b/src/argilla/server/apis/v1/handlers/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional from uuid import UUID from fastapi import APIRouter, Depends, HTTPException, Query, Security, status @@ -69,16 +69,22 @@ async def _get_dataset( @router.get("/me/datasets", response_model=Datasets) async def list_current_user_datasets( - *, db: AsyncSession = Depends(get_async_db), current_user: User = Security(auth.get_current_user) + *, + db: AsyncSession = Depends(get_async_db), + workspace_id: Optional[UUID] = None, + current_user: User = Security(auth.get_current_user), ): - await authorize(current_user, DatasetPolicyV1.list) - - if current_user.is_owner: - dataset_list = await datasets.list_datasets(db) - return Datasets(items=dataset_list) - - await current_user.awaitable_attrs.datasets - return Datasets(items=current_user.datasets) + await authorize(current_user, DatasetPolicyV1.list(workspace_id)) + + if not workspace_id: + if current_user.is_owner: + dataset_list = await datasets.list_datasets(db) + else: + await current_user.awaitable_attrs.datasets + dataset_list = current_user.datasets + else: + dataset_list = await datasets.list_datasets_by_workspace_id(db, workspace_id) + return Datasets(items=dataset_list) @router.get("/datasets/{dataset_id}/fields", response_model=Fields) diff --git a/src/argilla/server/policies.py 
b/src/argilla/server/policies.py index 7f27788bbb..923df2ea63 100644 --- a/src/argilla/server/policies.py +++ b/src/argilla/server/policies.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Optional from uuid import UUID from sqlalchemy.ext.asyncio import async_object_session @@ -210,8 +210,13 @@ async def is_allowed(actor: User) -> bool: class DatasetPolicyV1: @classmethod - async def list(cls, actor: User) -> bool: - return True + def list(cls, workspace_id: Optional[UUID] = None) -> PolicyAction: + async def is_allowed(actor: User) -> bool: + if actor.is_owner or workspace_id is None: + return True + return await _exists_workspace_user_by_user_and_workspace_id(actor, workspace_id) + + return is_allowed @classmethod def get(cls, dataset: Dataset) -> PolicyAction: diff --git a/tests/integration/client/sdk/v1/test_datasets.py b/tests/integration/client/sdk/v1/test_datasets.py index a712bc321f..b2abde784b 100644 --- a/tests/integration/client/sdk/v1/test_datasets.py +++ b/tests/integration/client/sdk/v1/test_datasets.py @@ -50,7 +50,7 @@ ) -@pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator]) +@pytest.mark.parametrize("role", [UserRole.owner, UserRole.admin, UserRole.annotator]) @pytest.mark.asyncio async def test_list_datasets(role: UserRole) -> None: dataset = await DatasetFactory.create() @@ -65,6 +65,38 @@ async def test_list_datasets(role: UserRole) -> None: assert isinstance(response.parsed[0], FeedbackDatasetModel) +@pytest.mark.parametrize( + "role, with_workspace, expected_length", + [ + (UserRole.owner, False, 2), + (UserRole.owner, True, 1), + (UserRole.admin, False, 0), + (UserRole.admin, True, 1), + (UserRole.annotator, False, 0), + (UserRole.annotator, True, 1), + ], +) +@pytest.mark.asyncio +async def test_list_datasets_by_workspace_id(role: UserRole, 
with_workspace: bool, expected_length: int) -> None: + workspace = await WorkspaceFactory.create() + dataset = await DatasetFactory.create(workspace=workspace) + user = await UserFactory.create(role=role, workspaces=[dataset.workspace] if with_workspace else []) + + another_workspace = await WorkspaceFactory.create() + await DatasetFactory.create(workspace=another_workspace) + + api = Argilla(api_key=user.api_key) + response = list_datasets( + client=api.client.httpx, workspace_id=str(dataset.workspace.id) if with_workspace else None + ) + + assert response.status_code == 200 + assert isinstance(response.parsed, list) + assert len(response.parsed) == expected_length + if expected_length > 0: + assert isinstance(response.parsed[0], FeedbackDatasetModel) + + @pytest.mark.parametrize("role", [UserRole.admin, UserRole.owner, UserRole.annotator]) @pytest.mark.asyncio async def test_get_datasets(role: UserRole) -> None: diff --git a/tests/unit/server/api/v1/test_datasets.py b/tests/unit/server/api/v1/test_datasets.py index d36252ee4b..97fb8841ea 100644 --- a/tests/unit/server/api/v1/test_datasets.py +++ b/tests/unit/server/api/v1/test_datasets.py @@ -85,7 +85,7 @@ @pytest.mark.asyncio class TestSuiteDatasets: - async def test_list_current_user_datasets(self, async_client: "AsyncClient", owner_auth_header: dict): + async def test_list_current_user_datasets(self, async_client: "AsyncClient", owner_auth_header: dict) -> None: dataset_a = await DatasetFactory.create(name="dataset-a") dataset_b = await DatasetFactory.create(name="dataset-b", guidelines="guidelines") dataset_c = await DatasetFactory.create(name="dataset-c", status=DatasetStatus.ready) @@ -125,7 +125,7 @@ async def test_list_current_user_datasets(self, async_client: "AsyncClient", own ] } - async def test_list_current_user_datasets_without_authentication(self, async_client: "AsyncClient"): + async def test_list_current_user_datasets_without_authentication(self, async_client: "AsyncClient") -> None: response = 
await async_client.get("/api/v1/me/datasets") assert response.status_code == 401 @@ -133,7 +133,7 @@ async def test_list_current_user_datasets_without_authentication(self, async_cli @pytest.mark.parametrize("role", [UserRole.annotator, UserRole.admin]) async def test_list_current_user_datasets_as_restricted_user_role( self, async_client: "AsyncClient", role: UserRole - ): + ) -> None: workspace = await WorkspaceFactory.create() user = await UserFactory.create(workspaces=[workspace], role=role) @@ -148,6 +148,31 @@ async def test_list_current_user_datasets_as_restricted_user_role( response_body = response.json() assert [dataset["name"] for dataset in response_body["items"]] == ["dataset-a", "dataset-b"] + @pytest.mark.parametrize("role", [UserRole.owner, UserRole.annotator, UserRole.admin]) + async def test_list_current_user_datasets_by_workspace_id( + self, async_client: "AsyncClient", role: UserRole + ) -> None: + workspace = await WorkspaceFactory.create() + another_workspace = await WorkspaceFactory.create() + + user = ( + await UserFactory.create(role=role) + if role == UserRole.owner + else await UserFactory.create(workspaces=[workspace], role=role) + ) + + await DatasetFactory.create(name="dataset-a", workspace=workspace) + await DatasetFactory.create(name="dataset-b", workspace=another_workspace) + + response = await async_client.get( + "/api/v1/me/datasets", params={"workspace_id": workspace.id}, headers={API_KEY_HEADER_NAME: user.api_key} + ) + + assert response.status_code == 200 + + response_body = response.json() + assert [dataset["name"] for dataset in response_body["items"]] == ["dataset-a"] + async def test_list_dataset_fields(self, async_client: "AsyncClient", owner_auth_header: dict): dataset = await DatasetFactory.create() text_field_a = await TextFieldFactory.create(