Skip to content

Commit

Permalink
feat: add FeedbackDataset search endpoint (#3068)
Browse files Browse the repository at this point in the history
# Description

This PR introduces a new endpoint `POST
/api/v1/me/datasets/{dataset_id}/records/search` that will allow the
user to search records using some basic queries over the fields.

Closes #3067

**Type of change**

- [x] New feature (non-breaking change which adds functionality)

**How Has This Been Tested**

New unit tests have been added to test the functionality of this new
endpoint. In addition, I manually tested the endpoint using the Argilla
dolly dataset.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: Francisco Aranda <francis@argilla.io>
  • Loading branch information
2 people authored and davidberenstein1957 committed Jun 7, 2023
1 parent 3e08baf commit 9384299
Show file tree
Hide file tree
Showing 8 changed files with 471 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ These are the section headers that we use:
- Added new status `draft` for the `Response` model.
- Added `LabelSelectionQuestionSettings` class allowing to create label selection (single-choice) questions in the API ([#3005])
- Added `MultiLabelSelectionQuestionSettings` class allowing to create multi-label selection (multi-choice) questions in the API ([#3010]).
- Added `POST /api/v1/me/datasets/{dataset_id}/records/search` endpoint ([#3068]).

### Changed

Expand Down
77 changes: 74 additions & 3 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from argilla.server.commons.telemetry import TelemetryClient, get_telemetry_client
from argilla.server.contexts import accounts, datasets
from argilla.server.database import get_db
from argilla.server.enums import ResponseStatusFilter
from argilla.server.models import User
from argilla.server.policies import DatasetPolicyV1, authorize
from argilla.server.schemas.v1.datasets import (
Expand All @@ -37,9 +38,15 @@
RecordInclude,
Records,
RecordsCreate,
ResponseStatusFilter,
SearchRecord,
SearchRecordsQuery,
SearchRecordsResult,
)
from argilla.server.search_engine import (
SearchEngine,
UserResponseStatusFilter,
get_search_engine,
)
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth

LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50
Expand Down Expand Up @@ -106,7 +113,7 @@ def list_current_user_dataset_records(
*,
db: Session = Depends(get_db),
dataset_id: UUID,
include: Optional[List[RecordInclude]] = Query([]),
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -285,6 +292,70 @@ async def create_dataset_records(
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))


@router.post(
"/me/datasets/{dataset_id}/records/search",
status_code=status.HTTP_200_OK,
response_model=SearchRecordsResult,
response_model_exclude_unset=True,
)
async def search_dataset_records(
*,
db: Session = Depends(get_db),
search_engine: SearchEngine = Depends(get_search_engine),
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
query: SearchRecordsQuery,
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
):
dataset = _get_dataset(db, dataset_id)
authorize(current_user, DatasetPolicyV1.search_records(dataset))

search_engine_query = query.query
if search_engine_query.text.field and not datasets.get_field_by_name_and_dataset_id(
db, search_engine_query.text.field, dataset_id
):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Field `{search_engine_query.text.field}` not found in dataset `{dataset_id}`.",
)

user_response_status_filter = None
if response_status:
user_response_status_filter = UserResponseStatusFilter(
user=current_user,
statuses=[response_status.value],
)

search_responses = await search_engine.search(
dataset=dataset,
query=search_engine_query,
user_response_status_filter=user_response_status_filter,
limit=limit,
)

record_id_score_map = {
response.record_id: {"query_score": response.score, "search_record": None}
for response in search_responses.items
}
records = datasets.get_records_by_ids(
db=db,
dataset_id=dataset_id,
record_ids=list(record_id_score_map.keys()),
include=include,
user_id=current_user.id,
)

for record in records:
record_id_score_map[record.id]["search_record"] = SearchRecord(
record=record, query_score=record_id_score_map[record.id]["query_score"]
)

return SearchRecordsResult(items=[record["search_record"] for record in record_id_score_map.values()])


@router.put("/datasets/{dataset_id}/publish", response_model=Dataset)
async def publish_dataset(
*,
Expand Down
24 changes: 22 additions & 2 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, func
from sqlalchemy.orm import Session, contains_eager, joinedload
from sqlalchemy.orm import Session, contains_eager, joinedload, noload

from argilla.server.contexts import accounts
from argilla.server.enums import ResponseStatusFilter
from argilla.server.models import (
Dataset,
DatasetStatus,
Expand All @@ -37,7 +38,6 @@
QuestionCreate,
RecordInclude,
RecordsCreate,
ResponseStatusFilter,
)
from argilla.server.schemas.v1.records import ResponseCreate
from argilla.server.schemas.v1.responses import ResponseUpdate
Expand Down Expand Up @@ -189,6 +189,26 @@ def get_record_by_id(db: Session, record_id: UUID):
return db.get(Record, record_id)


def get_records_by_ids(
db: Session,
dataset_id: UUID,
record_ids: List[UUID],
include: List[RecordInclude] = [],
user_id: Optional[UUID] = None,
) -> List[Record]:
query = db.query(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(record_ids))
if RecordInclude.responses in include:
if user_id:
query = query.outerjoin(
Response, and_(Response.record_id == Record.id, Response.user_id == user_id)
).options(contains_eager(Record.responses))
else:
query = query.options(joinedload(Record.responses))
else:
query = query.options(noload(Record.responses))
return query.all()


def list_records_by_dataset_id(
db: Session,
dataset_id: UUID,
Expand Down
22 changes: 22 additions & 0 deletions src/argilla/server/enums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum


class ResponseStatusFilter(str, Enum):
draft = "draft"
missing = "missing"
submitted = "submitted"
discarded = "discarded"
13 changes: 13 additions & 0 deletions src/argilla/server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ def create_question(cls, actor: User) -> bool:
def create_records(cls, actor: User) -> bool:
return actor.is_admin

@classmethod
def search_records(cls, dataset: Dataset) -> bool:
return lambda actor: (
actor.is_admin
or bool(
accounts.get_workspace_user_by_workspace_id_and_user_id(
Session.object_session(actor),
dataset.workspace_id,
actor.id,
)
)
)

@classmethod
def publish(cls, actor: User) -> bool:
return actor.is_admin
Expand Down
22 changes: 15 additions & 7 deletions src/argilla/server/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from pydantic import BaseModel, PositiveInt, conlist, constr, root_validator, validator
from pydantic import Field as PydanticField

from argilla.server.search_engine import Query

try:
from typing import Annotated, Literal
except ImportError:
Expand Down Expand Up @@ -285,13 +287,6 @@ class RecordInclude(str, Enum):
responses = "responses"


class ResponseStatusFilter(str, Enum):
draft = "draft"
missing = "missing"
submitted = "submitted"
discarded = "discarded"


class Record(BaseModel):
id: UUID
fields: Dict[str, Any]
Expand Down Expand Up @@ -347,3 +342,16 @@ def check_user_id_is_unique(cls, values):

class RecordsCreate(BaseModel):
items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS)


class SearchRecordsQuery(BaseModel):
query: Query


class SearchRecord(BaseModel):
record: Record
query_score: Optional[float]


class SearchRecordsResult(BaseModel):
items: List[SearchRecord]
8 changes: 4 additions & 4 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from uuid import UUID

from opensearchpy import AsyncOpenSearch, helpers
from pydantic import BaseModel
from pydantic.utils import GetterDict

from argilla.server.enums import ResponseStatusFilter
from argilla.server.models import (
Dataset,
Field,
Expand All @@ -31,7 +32,6 @@
ResponseStatus,
User,
)
from argilla.server.schemas.v1.datasets import ResponseStatusFilter
from argilla.server.settings import settings


Expand Down Expand Up @@ -205,7 +205,7 @@ async def search(
items = []
next_page_token = None
for hit in response["hits"]["hits"]:
items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
items.append(SearchResponseItem(record_id=UUID(hit["_id"]), score=hit["_score"]))
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
next_page_token = hit.get("_sort")

Expand Down Expand Up @@ -278,7 +278,7 @@ def _index_name_for_dataset(dataset: Dataset):
return f"rg.{dataset.id}"


async def get_search_engine():
async def get_search_engine() -> Generator[SearchEngine, None, None]:
config = dict(
hosts=settings.elasticsearch,
verify_certs=settings.elasticsearch_ssl_verify,
Expand Down
Loading

0 comments on commit 9384299

Please sign in to comment.