Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: combine response status filters #3359

Merged
merged 10 commits into from
Aug 2, 2023
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ These are the section headers that we use:
- `User.workspaces` is no longer an attribute but a property, and is calling `list_user_workspaces` to list all the workspace names for a given user ID ([#3334](https://github.com/argilla-io/argilla/pull/3334))
- Renamed `FeedbackDatasetConfig` to `DatasetConfig` and export/import from YAML as default instead of JSON (just used internally on `push_to_huggingface` and `from_huggingface` methods of `FeedbackDataset`) ([#3326](https://github.com/argilla-io/argilla/pull/3326)).
- The protected metadata fields now support non-textual info - existing datasets must be reindexed. See [docs](https://docs.argilla.io/en/latest/getting_started/installation/configurations/database_migrations.html#elasticsearch) for more details (Closes [#3332](https://github.com/argilla-io/argilla/issues/3332)).
- Updated `GET /api/v1/me/datasets/{dataset_id}/records` endpoint to allow getting records matching one of the response statuses provided via query param ([#3359](https://github.com/argilla-io/argilla/pull/3359)).
- Updated `POST /api/v1/me/datasets/{dataset_id}/records` endpoint to allow searching records matching one of the response statuses provided via query param ([#3359](https://github.com/argilla-io/argilla/pull/3359)).
- Updated `SearchEngine.search` method to allow searching records matching one of the response statuses provided ([#3359](https://github.com/argilla-io/argilla/pull/3359)).

### Fixed

Expand Down
18 changes: 12 additions & 6 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
Expand Down Expand Up @@ -119,7 +119,7 @@ async def list_current_user_dataset_records(
db: AsyncSession = Depends(get_async_db),
dataset_id: UUID,
include: List[RecordInclude] = Query([], description="Relationships to include in the response"),
response_status: Optional[ResponseStatusFilter] = Query(None),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -129,7 +129,13 @@ async def list_current_user_dataset_records(
await authorize(current_user, DatasetPolicyV1.get(dataset))

records = await datasets.list_records_by_dataset_id_and_user_id(
db, dataset_id, current_user.id, include=include, response_status=response_status, offset=offset, limit=limit
db,
dataset_id,
current_user.id,
include=include,
response_statuses=response_statuses,
offset=offset,
limit=limit,
)

return Records(items=records)
Expand Down Expand Up @@ -314,7 +320,7 @@ async def search_dataset_records(
dataset_id: UUID,
query: SearchRecordsQuery,
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
response_statuses: List[ResponseStatusFilter] = Query([], alias="response_status"),
offset: int = Query(0, ge=0),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
Expand All @@ -332,8 +338,8 @@ async def search_dataset_records(
)

user_response_status_filter = None
if response_status:
user_response_status_filter = UserResponseStatusFilter(user=current_user, status=response_status)
if response_statuses:
user_response_status_filter = UserResponseStatusFilter(user=current_user, statuses=response_statuses)

search_responses = await search_engine.search(
dataset=dataset,
Expand Down
29 changes: 22 additions & 7 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from uuid import UUID

from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, func, select
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import contains_eager, joinedload, selectinload

from argilla.server.contexts import accounts
Expand Down Expand Up @@ -267,20 +267,35 @@ async def list_records_by_dataset_id_and_user_id(
dataset_id: UUID,
user_id: UUID,
include: List[RecordInclude] = [],
response_status: Optional[ResponseStatusFilter] = None,
response_statuses: List[ResponseStatusFilter] = [],
offset: int = 0,
limit: int = LIST_RECORDS_LIMIT,
) -> List[Record]:
response_statuses_ = [
ResponseStatus(response_status)
for response_status in response_statuses
if response_status != ResponseStatusFilter.missing
]

response_status_filter_expressions = []

if response_statuses_:
response_status_filter_expressions.append(Response.status.in_(response_statuses_))

if ResponseStatusFilter.missing in response_statuses:
response_status_filter_expressions.append(Response.status.is_(None))

query = (
select(Record)
.filter(Record.dataset_id == dataset_id)
.outerjoin(Response, and_(Response.record_id == Record.id, Response.user_id == user_id))
.outerjoin(
Response,
and_(Response.record_id == Record.id, Response.user_id == user_id),
)
)

if response_status == ResponseStatusFilter.missing:
query = query.filter(Response.status == None) # noqa: E711
elif response_status is not None:
query = query.filter(Response.status == ResponseStatus(response_status))
if response_status_filter_expressions:
query = query.filter(or_(*response_status_filter_expressions))

if RecordInclude.responses in include:
query = query.options(contains_eager(Record.responses))
Expand Down
34 changes: 21 additions & 13 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ class Query:
@dataclasses.dataclass
class UserResponseStatusFilter:
user: User
status: ResponseStatusFilter
statuses: List[ResponseStatusFilter]


@dataclasses.dataclass
Expand Down Expand Up @@ -185,15 +185,11 @@ async def search(

text_query = self._text_query_builder(dataset, text=query.text)

bool_query: dict = {"must": [text_query]}
bool_query = {"must": [text_query]}
if user_response_status_filter:
bool_query["filter"] = self._response_status_filter_builder(user_response_status_filter)

body = {
"_source": False,
"query": {"bool": bool_query},
# "sort": [{"_score": "desc"}, {"id": "asc"}],
}
body = {"_source": False, "query": {"bool": bool_query}}

response = await self.client.search(
index=self._index_name_for_dataset(dataset),
Expand Down Expand Up @@ -230,7 +226,7 @@ def _mapping_for_fields(self, fields: List[Field]):
def _dynamic_templates_for_question_responses(self, questions: List[Question]) -> List[dict]:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/dynamic-templates.html
return [
{"status_responses": {"path_match": f"responses.*.status", "mapping": {"type": "keyword"}}},
{"status_responses": {"path_match": "responses.*.status", "mapping": {"type": "keyword"}}},
*[
{
f"{question.name}_responses": {
Expand Down Expand Up @@ -280,15 +276,27 @@ async def _get_index_or_raise(self, dataset: Dataset):
def _index_name_for_dataset(dataset: Dataset):
return f"rg.{dataset.id}"

def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter):
def _response_status_filter_builder(self, status_filter: UserResponseStatusFilter) -> Optional[Dict[str, Any]]:

if not status_filter.statuses:
return None

user_response_field = f"responses.{status_filter.user.username}"

if status_filter.status == ResponseStatusFilter.missing:
statuses = [
ResponseStatus(status).value for status in status_filter.statuses if status != ResponseStatusFilter.missing
]

filters = []
if ResponseStatusFilter.missing in status_filter.statuses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-exists-query.html
return [{"bool": {"must_not": {"exists": {"field": user_response_field}}}}]
else:
filters.append({"bool": {"must_not": {"exists": {"field": user_response_field}}}})

if statuses:
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-terms-query.html
return [{"term": {f"{user_response_field}.status": status_filter.status}}]
filters.append({"terms": {f"{user_response_field}.status": statuses}})

return {"bool": {"should": filters, "minimum_should_match": 1}}


async def get_search_engine() -> AsyncGenerator[SearchEngine, None]:
Expand Down
7 changes: 4 additions & 3 deletions tests/client/sdk/users/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ def test_whoami_errors() -> None:
whoami(AuthenticatedClient(base_url="http://localhost:6900", token="wrong-apikey"))


def test_list_users(owner: "ServerUser") -> None:
UserFactory.create(username="user_1")
UserFactory.create(username="user_2")
@pytest.mark.asyncio
async def test_list_users(owner: "ServerUser") -> None:
await UserFactory.create(username="user_1")
await UserFactory.create(username="user_2")
httpx_client = ArgillaSingleton.init(api_key=owner.api_key).http_client.httpx

response = list_users(client=httpx_client)
Expand Down
32 changes: 19 additions & 13 deletions tests/server/api/v1/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest
from argilla._constants import API_KEY_HEADER_NAME
from argilla.server.apis.v1.handlers.datasets import LIST_DATASET_RECORDS_LIMIT_DEFAULT
from argilla.server.enums import ResponseStatusFilter
from argilla.server.models import (
Dataset,
DatasetStatus,
Expand Down Expand Up @@ -934,10 +935,21 @@ async def create_records_with_response(
await ResponseFactory.create(record=record, user=user, values=response_values, status=response_status)


@pytest.mark.parametrize("response_status_filter", ["missing", "discarded", "submitted", "draft"])
@pytest.mark.parametrize(
"response_status_filters",
[
[ResponseStatusFilter.missing],
[ResponseStatusFilter.draft],
[ResponseStatusFilter.submitted],
[ResponseStatusFilter.discarded],
[ResponseStatusFilter.missing, ResponseStatusFilter.draft],
[ResponseStatusFilter.submitted, ResponseStatusFilter.discarded],
[ResponseStatusFilter.missing, ResponseStatusFilter.draft, ResponseStatusFilter.discarded],
],
)
@pytest.mark.asyncio
async def test_list_current_user_dataset_records_with_response_status_filter(
client: TestClient, owner: "User", owner_auth_header: dict, response_status_filter: str
client: TestClient, owner: "User", owner_auth_header: dict, response_status_filters: List[ResponseStatusFilter]
):
num_responses_per_status = 10
response_values = {"input_ok": {"value": "yes"}, "output_ok": {"value": "yes"}}
Expand All @@ -957,20 +969,14 @@ async def test_list_current_user_dataset_records_with_response_status_filter(
other_dataset = await DatasetFactory.create()
await RecordFactory.create_batch(size=2, dataset=other_dataset)

response = client.get(
f"/api/v1/me/datasets/{dataset.id}/records?response_status={response_status_filter}&include=responses",
headers=owner_auth_header,
)
params = [("include", RecordInclude.responses.value)]
params.extend(("response_status", status_filter.value) for status_filter in response_status_filters)
response = client.get(f"/api/v1/me/datasets/{dataset.id}/records", headers=owner_auth_header, params=params)

assert response.status_code == 200
response_json = response.json()

assert len(response_json["items"]) == num_responses_per_status

if response_status_filter == "missing":
assert all([len(record["responses"]) == 0 for record in response_json["items"]])
else:
assert all([record["responses"][0]["status"] == response_status_filter for record in response_json["items"]])
assert len(response_json["items"]) == num_responses_per_status * len(response_status_filters)


@pytest.mark.asyncio
Expand Down Expand Up @@ -3005,7 +3011,7 @@ async def test_search_dataset_records_with_response_status_filter(
mock_search_engine.search.assert_called_once_with(
dataset=dataset,
query=Query(text=TextQuery(q="Hello", field="input")),
user_response_status_filter=UserResponseStatusFilter(user=owner, status=ResponseStatus.submitted),
user_response_status_filter=UserResponseStatusFilter(user=owner, statuses=[ResponseStatus.submitted]),
offset=0,
limit=LIST_DATASET_RECORDS_LIMIT_DEFAULT,
)
Expand Down
Loading
Loading