Skip to content

Commit

Permalink
feat: combine response status filters (#3359)
Browse files Browse the repository at this point in the history
# Description

This PR updates the `GET /api/v1/me/datasets/{dataset_id}/records` and
`POST /api/v1/me/datasets/{dataset_id}/records/search` and the
`SearchEngine.search` method, so records can be filtered using more than
one response status
(`/api/v1/...?response_status=submitted&response_status=discarded`).

Closes #3259

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

Locally and the unit tests for the endpoints and `SearchEngine` have
been updated to cover the cases in which more than one value for the
response status filter is provided.

**Checklist**

- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Paco Aranda <francis@argilla.io>
  • Loading branch information
2 people authored and keithCuniah committed Aug 3, 2023
1 parent 0188cbe commit 2647ed8
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 79 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ These are the section headers that we use:
- Limit rating questions values to the positive range [1, 10] (Closes [#3451](https://github.com/argilla-io/argilla/issues/3451)).
- Updated `POST /api/users` endpoint to be able to provide a list of workspace names to which the user should be linked to ([#3462](https://github.com/argilla-io/argilla/pull/3462)).
- Updated Python client `User.create` method to be able to provide a list of workspace names to which the user should be linked to ([#3462](https://github.com/argilla-io/argilla/pull/3462)).
- Updated `GET /api/v1/me/datasets/{dataset_id}/records` endpoint to allow getting records matching one of the response statuses provided via query param.([#3359](https://github.com/argilla-io/argilla/pull/3359)).
- Updated `POST /api/v1/me/datasets/{dataset_id}/records` endpoint to allow searching records matching one of the response statuses provided via query param.([#3359](https://github.com/argilla-io/argilla/pull/3359)).
- Updated `SearchEngine.search` method to allow searching records matching one of the response statuses provided ([#3359](https://github.com/argilla-io/argilla/pull/3359)).

## [1.13.3](https://github.com/argilla-io/argilla/compare/v1.13.2...v1.13.3)

Expand Down
18 changes: 12 additions & 6 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
Expand Down Expand Up @@ -108,7 +108,7 @@ async def list_current_user_dataset_records(
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
response_status: Optional[ResponseStatusFilter] = Query(None),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -118,7 +118,13 @@ async def list_current_user_dataset_records(
await authorize(current_user, DatasetPolicyV1.get(dataset))

records = await datasets.list_records_by_dataset_id_and_user_id(
db, dataset_id, current_user.id, include=include, response_status=response_status, offset=offset, limit=limit
db,
dataset_id,
current_user.id,
include=include,
response_statuses=response_statuses,
offset=offset,
limit=limit,
)

return Records(items=records)
Expand Down Expand Up @@ -297,7 +303,7 @@ async def search_dataset_records(
dataset_id: UUID,
query: SearchRecordsQuery,
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -315,8 +321,8 @@ async def search_dataset_records(
)

user_response_status_filter = None
if response_status:
user_response_status_filter = UserResponseStatusFilter(user=current_user, status=response_status)
if response_statuses:
user_response_status_filter = UserResponseStatusFilter(user=current_user, statuses=response_statuses)

search_responses = await search_engine.search(
dataset=dataset,
Expand Down
29 changes: 22 additions & 7 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from uuid import UUID

from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, delete, func, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla.server.contexts import accounts
Expand Down Expand Up @@ -267,20 +267,35 @@ async def list_records_by_dataset_id_and_user_id(
dataset_id: UUID,
user_id: UUID,
include: List[RecordInclude] = [],
response_status: Optional[ResponseStatusFilter] = None,
response_statuses: List[ResponseStatusFilter] = [],
offset: int = 0,
limit: int = LIST_RECORDS_LIMIT,
) -> List[Record]:
response_statuses_ = [
ResponseStatus(response_status)
for response_status in response_statuses
if response_status != ResponseStatusFilter.missing
]

response_status_filter_expressions = []

if response_statuses_:
response_status_filter_expressions.append(Response.status.in_(response_statuses_))

if ResponseStatusFilter.missing in response_statuses:
response_status_filter_expressions.append(Response.status.is_(None))

query = (
select(Record)
.filter(Record.dataset_id == dataset_id)
.outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id))
.outerjoin(
Response,
and_(Response.record_id == Record.id, Response.user_id == user_id),
)
)

if response_status == ResponseStatusFilter.missing:
query = query.filter(Response.status == None) # noqa: E711
elif response_status is not None:
query = query.filter(Response.status == ResponseStatus(response_status))
if response_status_filter_expressions:
query = query.filter(or_(*response_status_filter_expressions))

if RecordInclude.responses in include:
query = query.options(contains_eager(Record.responses))
Expand Down
33 changes: 20 additions & 13 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Query:
@dataclasses.dataclass
class UserResponseStatusFilter:
user: User
status: ResponseStatusFilter
statuses: List[ResponseStatusFilter]


@dataclasses.dataclass
Expand Down Expand Up @@ -190,15 +190,11 @@ async def search(

text_query = self._text_query_builder(dataset, text=query.text)

bool_query: dict = {"must": [text_query]}
bool_query = {"must": [text_query]}
if user_response_status_filter:
bool_query["filter"] = self._response_status_filter_builder(user_response_status_filter)

body = {
"_source": False,
"query": {"bool": bool_query},
# "sort": [{"_score": "desc"}, {"id": "asc"}],
}
body = {"_source": False, "query": {"bool": bool_query}}

response = await self.client.search(
index=self._index_name_for_dataset(dataset),
Expand Down Expand Up @@ -235,7 +231,7 @@ def _mapping_for_fields(self, fields: List[Field]):
def _dynamic_templates_for_question_responses(self, questions: List[Question]) -> List[dict]:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
return [
{"status_responses": {"path_match": f"responses.*.status", "mapping": {"type": "keyword"}}},
{"status_responses": {"path_match": "responses.*.status", "mapping": {"type": "keyword"}}},
*[
{
f"{question.name}_responses": {
Expand Down Expand Up @@ -285,15 +281,26 @@ async def _get_index_or_raise(self, dataset: Dataset):
def _index_name_for_dataset(dataset: Dataset):
return f"rg.{dataset.id}"

def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter):
def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Optional[Dict[str, Any]]:
if not status_filter.statuses:
return None

user_response_field = f"responses.{status_filter.user.username}"

if status_filter.status == ResponseStatusFilter.missing:
statuses = [
ResponseStatus(status).value for status in status_filter.statuses if status != ResponseStatusFilter.missing
]

filters = []
if ResponseStatusFilter.missing in status_filter.statuses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html
return [{"bool": {"must_not": {"exists": {"field": user_response_field}}}}]
else:
filters.append({"bool": {"must_not": {"exists": {"field": user_response_field}}}})

if statuses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
return [{"term": {f"{user_response_field}.status": status_filter.status}}]
filters.append({"terms": {f"{user_response_field}.status": statuses}})

return {"bool": {"should": filters, "minimum_should_match": 1}}

async def _bulk_op(self, actions: List[Dict[str, Any]]):
_, errors = await helpers.async_bulk(client=self.client, actions=actions, raise_on_error=False)
Expand Down
7 changes: 4 additions & 3 deletions tests/integration/client/sdk/api/test_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ def test_whoami_errors() -> None:
whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey"))


def test_list_users(owner: "ServerUser") -> None:
UserFactory.create(username="user_1")
UserFactory.create(username="user_2")
@pytest.mark.asyncio
async def test_list_users(owner: "ServerUser") -> None:
await UserFactory.create(username="user_1")
await UserFactory.create(username="user_2")
httpx_client = ArgillaSingleton.init(api_key=owner.api_key).http_client.httpx

response = list_users(client=httpx_client)
Expand Down
32 changes: 19 additions & 13 deletions tests/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.apis.v1.handlers.datasets import LIST_DATASET_RECORDS_LIMIT_DEFAULT
from argilla.server.enums import ResponseStatusFilter
from argilla.server.models import (
Dataset,
DatasetStatus,
Expand Down Expand Up @@ -937,10 +938,21 @@ async def create_records_with_response(
await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status)


@pytest.mark.parametrize("response_status_filter", ["missing", "discarded", "submitted", "draft"])
@pytest.mark.parametrize(
"response_status_filters",
[
[ResponseStatusFilter.missing],
[ResponseStatusFilter.draft],
[ResponseStatusFilter.submitted],
[ResponseStatusFilter.discarded],
[ResponseStatusFilter.missing, ResponseStatusFilter.draft],
[ResponseStatusFilter.submitted, ResponseStatusFilter.discarded],
[ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded],
],
)
@pytest.mark.asyncio
async def test_list_current_user_dataset_records_with_response_status_filter(
client: TestClient, owner: "User", owner_auth_header: dict, response_status_filter: str
client: TestClient, owner: "User", owner_auth_header: dict, response_status_filters: List[ResponseStatusFilter]
):
num_responses_per_status = 10
response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}
Expand All @@ -960,20 +972,14 @@ async def test_list_current_user_dataset_records_with_response_status_filter(
other_dataset = await DatasetFactory.create()
await RecordFactory.create_batch(size=2, dataset=other_dataset)

response = client.get(
f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses",
headers=owner_auth_header,
)
params = [("include", RecordInclude.responses.value)]
params.extend(("response_status", status_filter.value) for status_filter in response_status_filters)
response = client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params=params)

assert response.status_code == 200
response_json = response.json()

assert len(response_json["items"]) == num_responses_per_status

if response_status_filter == "missing":
assert all([len(record["responses"]) == 0 for record in response_json["items"]])
else:
assert all([record["responses"][0]["status"] == response_status_filter for record in response_json["items"]])
assert len(response_json["items"]) == num_responses_per_status * len(response_status_filters)


@pytest.mark.asyncio
Expand Down Expand Up @@ -3019,7 +3025,7 @@ async def test_search_dataset_records_with_response_status_filter(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=Query(text=TextQuery(q="Hello", field="input")),
user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatus.submitted),
user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatus.submitted]),
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3009,7 +3009,7 @@ async def test_search_dataset_records_with_response_status_filter(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=Query(text=TextQuery(q="Hello", field="input")),
user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatusFilter.submitted),
user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatusFilter.submitted]),
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
)
Expand Down
Loading

0 comments on commit 2647ed8

Please sign in to comment.