Skip to content

Commit

Permalink
feat: add response_status param to `GET /api/v1/datasets/{dataset_i…
Browse files Browse the repository at this point in the history
…d}/records` (#3613)

# Description

This PR adds the `response_status` query param to `GET
/api/v1/datasets/{dataset_id}/records` too, as previously it was only
available for `GET /api/v1/me/datasets/{dataset_id}/records` (added in
#3359); its absence was blocking
other developments related to the record listing.

Besides that, we've also merged `list_records_by_dataset_id` and
`list_records_by_dataset_id_and_user_id` into a single
`list_records_by_dataset_id` whose `user_id` arg is optional, so
that the response filter based on the `user_id` is only applied when
`user_id is not None`.

**Type of change**

- [X] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

- [X] Add unit tests for `list_dataset_records` using
`response_statuses` via `response_status` alias for `GET
/api/v1/datasets/{dataset_id}/records`

**Checklist**

- [ ] I added relevant documentation
- [X] follows the style guidelines of this project
- [X] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
alvarobartt authored Aug 23, 2023
1 parent 00a7d2e commit 4bbeb21
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 33 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ These are the section headers that we use:

- Added `login` function in `argilla.client.login` to login into an Argilla server and store the credentials locally ([#3582](https://github.com/argilla-io/argilla/pull/3582)).
- Added `login` command to login into an Argilla server ([#3600](https://github.com/argilla-io/argilla/pull/3600)).
- Added `response_status` param to `GET /api/v1/datasets/{dataset_id}/records` so that records can be filtered by response status, as was already possible for `GET /api/v1/me/datasets/{dataset_id}/records` ([#3613](https://github.com/argilla-io/argilla/pull/3613)).

### Changed

Expand Down
7 changes: 5 additions & 2 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def list_current_user_dataset_records(

await authorize(current_user, DatasetPolicyV1.get(dataset))

records = await datasets.list_records_by_dataset_id_and_user_id(
records = await datasets.list_records_by_dataset_id(
db,
dataset_id,
current_user.id,
Expand All @@ -138,6 +138,7 @@ async def list_dataset_records(
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -146,7 +147,9 @@ async def list_dataset_records(

await authorize(current_user, DatasetPolicyV1.list_dataset_records_with_all_responses(dataset))

records = await datasets.list_records_by_dataset_id(db, dataset_id, include=include, offset=offset, limit=limit)
records = await datasets.list_records_by_dataset_id(
db, dataset_id, include=include, response_statuses=response_statuses, offset=offset, limit=limit
)

return Records(items=records)

Expand Down
1 change: 1 addition & 0 deletions src/argilla/server/apis/v1/handlers/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING
from uuid import UUID

Expand Down
23 changes: 4 additions & 19 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,24 +267,7 @@ async def get_records_by_ids(
async def list_records_by_dataset_id(
db: "AsyncSession",
dataset_id: UUID,
include: List[RecordInclude] = [],
offset: int = 0,
limit: int = LIST_RECORDS_LIMIT,
) -> List[Record]:
query = select(Record).filter(Record.dataset_id == dataset_id)
if RecordInclude.responses in include:
query = query.options(joinedload(Record.responses))
if RecordInclude.suggestions in include:
query = query.options(joinedload(Record.suggestions))
query = query.order_by(Record.inserted_at.asc()).offset(offset).limit(limit)
result = await db.execute(query)
return result.unique().scalars().all()


async def list_records_by_dataset_id_and_user_id(
db: "AsyncSession",
dataset_id: UUID,
user_id: UUID,
user_id: Optional[UUID] = None,
include: List[RecordInclude] = [],
response_statuses: List[ResponseStatusFilter] = [],
offset: int = 0,
Expand All @@ -309,7 +292,9 @@ async def list_records_by_dataset_id_and_user_id(
.filter(Record.dataset_id == dataset_id)
.outerjoin(
Response,
and_(Response.record_id == Record.id, Response.user_id == user_id),
Response.record_id == Record.id
if user_id is None
else and_(Response.record_id == Record.id, Response.user_id == user_id),
)
)

Expand Down
85 changes: 73 additions & 12 deletions tests/unit/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from datetime import datetime
from typing import TYPE_CHECKING, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from unittest.mock import ANY, MagicMock
from uuid import UUID, uuid4

Expand Down Expand Up @@ -599,6 +599,78 @@ async def test_list_dataset_records_with_offset_and_limit(
response_body = response.json()
assert [item["id"] for item in response_body["items"]] == [str(record_c.id)]

# Helper function to create records with responses
async def create_records_with_response(
    self,
    num_records: int,
    dataset: Dataset,
    user: User,
    response_status: ResponseStatus,
    response_values: Optional[dict] = None,
):
    """Create a batch of records and attach one response to each of them.

    Args:
        num_records: Number of records to create with `RecordFactory`.
        dataset: Dataset the new records are created for.
        user: User set as the author of every created response.
        response_status: Status given to every created response.
        response_values: Values passed to every created response; `None`
            is forwarded unchanged when no values are given.
    """
    created_records = await RecordFactory.create_batch(size=num_records, dataset=dataset)
    for created_record in created_records:
        await ResponseFactory.create(
            record=created_record,
            user=user,
            values=response_values,
            status=response_status,
        )

@pytest.mark.parametrize(
    "response_status_filter", ["missing", "discarded", "submitted", "draft", ["submitted", "draft"]]
)
async def test_list_dataset_records_with_response_status_filter(
    self,
    async_client: "AsyncClient",
    owner: "User",
    owner_auth_header: dict,
    response_status_filter: Union[str, List[str]],
):
    """Check the `response_status` filter of `GET /api/v1/datasets/{dataset_id}/records`.

    The dataset is seeded with the same number of records for each case:
    no response, a discarded response, a submitted response and a draft
    response. The request must return exactly that number of records per
    requested status, and every returned response must have a requested status.
    """
    num_responses_per_status = 10
    response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}

    dataset = await DatasetFactory.create()
    # Records with no response at all (the "missing" case).
    await RecordFactory.create_batch(size=num_responses_per_status, dataset=dataset)
    # Discarded responses are created without values.
    await self.create_records_with_response(num_responses_per_status, dataset, owner, ResponseStatus.discarded)
    # Submitted and draft responses share the same values.
    for status_with_values in (ResponseStatus.submitted, ResponseStatus.draft):
        await self.create_records_with_response(
            num_responses_per_status, dataset, owner, status_with_values, response_values
        )

    # Records from another dataset must not appear in the listing.
    other_dataset = await DatasetFactory.create()
    await RecordFactory.create_batch(size=2, dataset=other_dataset)

    # Normalise the parametrized value so a single status is handled as a one-item list.
    if isinstance(response_status_filter, str):
        response_status_filter = [response_status_filter]
    response_status_filter_url = [
        f"response_status={response_status}" for response_status in response_status_filter
    ]

    response = await async_client.get(
        f"/api/v1/datasets/{dataset.id}/records?{'&'.join(response_status_filter_url)}&include=responses",
        headers=owner_auth_header,
    )

    assert response.status_code == 200
    items = response.json()["items"]

    expected_total = num_responses_per_status * len(response_status_filter)
    assert len(items) == expected_total

    if "missing" in response_status_filter:
        records_without_responses = [record for record in items if len(record["responses"]) == 0]
        assert len(records_without_responses) >= num_responses_per_status

    # Only records that have a response are checked against the requested statuses.
    assert all(
        record["responses"][0]["status"] in response_status_filter
        for record in items
        if len(record["responses"]) > 0
    )

async def test_list_dataset_records_without_authentication(self, async_client: "AsyncClient"):
dataset = await DatasetFactory.create()

Expand Down Expand Up @@ -925,17 +997,6 @@ async def test_list_current_user_dataset_records_with_offset_and_limit(
response_body = response.json()
assert [item["id"] for item in response_body["items"]] == [str(record_c.id)]

async def create_records_with_response(
self,
num_records: int,
dataset: Dataset,
user: User,
response_status: ResponseStatus,
response_values: Optional[dict] = None,
):
for record in await RecordFactory.create_batch(size=num_records, dataset=dataset):
await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status)

@pytest.mark.parametrize("response_status_filter", ["missing", "discarded", "submitted", "draft"])
async def test_list_current_user_dataset_records_with_response_status_filter(
self, async_client: "AsyncClient", owner: "User", owner_auth_header: dict, response_status_filter: str
Expand Down

0 comments on commit 4bbeb21

Please sign in to comment.