Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add FeedbackDataset search endpoint #3068

Merged
merged 23 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8bb8756
feat: add `search_records` policy method
gabrielmbmb Jun 1, 2023
c5e3731
feat: add `get_search_engine` return type hint
gabrielmbmb Jun 1, 2023
673b41a
feat: add search records endpoint
gabrielmbmb Jun 1, 2023
4450a3b
feat: update `search_records` policy
gabrielmbmb Jun 1, 2023
da873d4
Merge branch 'develop' into feature/search-endpoint
gabrielmbmb Jun 1, 2023
92df1f1
Merge branch 'develop' into feature/search-endpoint
gabrielmbmb Jun 2, 2023
a6356e8
Merge branch 'develop' into feature/search-endpoint
gabrielmbmb Jun 2, 2023
60751a1
Merge branch 'develop' into feature/search-endpoint
gabrielmbmb Jun 2, 2023
4da4b7a
fix: `record_id` was `str` instead of `UUID`
gabrielmbmb Jun 2, 2023
25c9cfd
fix: include response returning `[]` as query result
gabrielmbmb Jun 2, 2023
a78dc13
feat: add unit tests for search endpoint
gabrielmbmb Jun 2, 2023
32f9fc6
fix: add missing `/me`
gabrielmbmb Jun 2, 2023
b63c14e
feat: remove `_merge_search_records` function
gabrielmbmb Jun 2, 2023
b90db2b
feat: return results sorted based on score
gabrielmbmb Jun 2, 2023
43b5be8
docs: add search endpoint
gabrielmbmb Jun 2, 2023
1e9ab40
Merge branch 'develop' into feature/search-endpoint
gabrielmbmb Jun 5, 2023
7d1605f
feat: remove `Dict` import
gabrielmbmb Jun 5, 2023
c291df0
feat: update search unit tests to use `mock_search_engine`
gabrielmbmb Jun 5, 2023
50ae350
feat: allow `query_score` to be optional
gabrielmbmb Jun 5, 2023
86b27df
feat: add `user_id` to query
gabrielmbmb Jun 5, 2023
d3748bc
refactor: records sorting according to `SearchEngine` results
gabrielmbmb Jun 5, 2023
75717b3
feat: add other users `Responses`
gabrielmbmb Jun 5, 2023
fc5a2b7
feat: add `SearchRecordsQuery` schema
gabrielmbmb Jun 5, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ These are the section headers that we use:
- Added new status `draft` for the `Response` model.
- Added `LabelSelectionQuestionSettings` class allowing to create label selection (single-choice) questions in the API ([#3005])
- Added `MultiLabelSelectionQuestionSettings` class allowing to create multi-label selection (multi-choice) questions in the API ([#3010]).
- Added `POST /api/v1/me/datasets/{dataset_id}/records/search` endpoint ([#3068]).

### Changed

Expand Down
71 changes: 68 additions & 3 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query, Security, status
Expand All @@ -38,10 +38,20 @@
Records,
RecordsCreate,
ResponseStatusFilter,
SearchRecord,
SearchRecordsResult,
)
from argilla.server.search_engine import Query as SearchEngineQuery
from argilla.server.search_engine import (
SearchEngine,
UserResponseStatusFilter,
get_search_engine,
)
from argilla.server.search_engine import SearchEngine, get_search_engine
from argilla.server.security import auth

if TYPE_CHECKING:
from argilla.server.search_engine import SearchResponses

LIST_DATASET_RECORDS_LIMIT_DEFAULT = 50
LIST_DATASET_RECORDS_LIMIT_LTE = 1000

Expand Down Expand Up @@ -106,7 +116,7 @@ def list_current_user_dataset_records(
*,
db: Session = Depends(get_db),
dataset_id: UUID,
include: Optional[List[RecordInclude]] = Query([]),
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
offset: int = 0,
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
Expand Down Expand Up @@ -285,6 +295,61 @@ async def create_dataset_records(
raise HTTPException(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=str(err))


@router.post(
"/me/datasets/{dataset_id}/records/search",
status_code=status.HTTP_200_OK,
response_model=SearchRecordsResult,
response_model_exclude_unset=True,
)
async def search_dataset_records(
*,
db: Session = Depends(get_db),
search_engine: SearchEngine = Depends(get_search_engine),
telemetry_client: TelemetryClient = Depends(get_telemetry_client),
dataset_id: UUID,
query: SearchEngineQuery,
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
include: List[RecordInclude] = Query([]),
response_status: Optional[ResponseStatusFilter] = Query(None),
limit: int = Query(default=LIST_DATASET_RECORDS_LIMIT_DEFAULT, lte=LIST_DATASET_RECORDS_LIMIT_LTE),
current_user: User = Security(auth.get_current_user),
):
dataset = _get_dataset(db, dataset_id)
authorize(current_user, DatasetPolicyV1.search_records(dataset))

if query.text.field and not datasets.get_field_by_name_and_dataset_id(db, query.text.field, dataset_id):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Field `{query.text.field}` not found in dataset `{dataset_id}`.",
)

user_response_status_filter = None
if response_status:
user_response_status_filter = UserResponseStatusFilter(
user=current_user,
statuses=[response_status.value],
)

search_responses = await search_engine.search(
dataset=dataset,
query=query,
user_response_status_filter=user_response_status_filter,
limit=limit,
)

record_id_score_map = {response.record_id: response.score for response in search_responses.items}
records = datasets.get_records_by_ids(
db=db, dataset_id=dataset_id, record_ids=list(record_id_score_map.keys()), include=include
)

return SearchRecordsResult(
items=sorted(
[SearchRecord(record=record, query_score=record_id_score_map[record.id]) for record in records],
key=lambda x: x.query_score,
reverse=True,
)
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
)


@router.put("/datasets/{dataset_id}/publish", response_model=Dataset)
async def publish_dataset(
*,
Expand Down
13 changes: 12 additions & 1 deletion src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from fastapi.encoders import jsonable_encoder
from sqlalchemy import and_, func
from sqlalchemy.orm import Session, contains_eager, joinedload
from sqlalchemy.orm import Session, contains_eager, joinedload, noload

from argilla.server.contexts import accounts
from argilla.server.models import (
Expand Down Expand Up @@ -189,6 +189,17 @@ def get_record_by_id(db: Session, record_id: UUID):
return db.get(Record, record_id)


def get_records_by_ids(
db: Session, dataset_id: UUID, record_ids: List[UUID], include: List[RecordInclude] = []
) -> List[Record]:
query = db.query(Record).filter(Record.dataset_id == dataset_id, Record.id.in_(record_ids))
if RecordInclude.responses in include:
query = query.options(joinedload(Record.responses))
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
else:
query = query.options(noload(Record.responses))
return query.all()


def list_records_by_dataset_id(
db: Session,
dataset_id: UUID,
Expand Down
13 changes: 13 additions & 0 deletions src/argilla/server/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,19 @@ def create_question(cls, actor: User) -> bool:
def create_records(cls, actor: User) -> bool:
return actor.is_admin

@classmethod
def search_records(cls, dataset: Dataset) -> bool:
return lambda actor: (
actor.is_admin
or bool(
accounts.get_workspace_user_by_workspace_id_and_user_id(
Session.object_session(actor),
dataset.workspace_id,
actor.id,
)
)
)

@classmethod
def publish(cls, actor: User) -> bool:
return actor.is_admin
Expand Down
9 changes: 9 additions & 0 deletions src/argilla/server/schemas/v1/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,12 @@ def check_user_id_is_unique(cls, values):

class RecordsCreate(BaseModel):
items: conlist(item_type=RecordCreate, min_items=RECORDS_CREATE_MIN_ITEMS, max_items=RECORDS_CREATE_MAX_ITEMS)


class SearchRecord(BaseModel):
record: Record
query_score: float
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved


class SearchRecordsResult(BaseModel):
items: List[SearchRecord]
6 changes: 3 additions & 3 deletions src/argilla/server/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Union
from uuid import UUID

from opensearchpy import AsyncOpenSearch, helpers
Expand Down Expand Up @@ -205,7 +205,7 @@ async def search(
items = []
next_page_token = None
for hit in response["hits"]["hits"]:
items.append(SearchResponseItem(record_id=hit["_id"], score=hit["_score"]))
items.append(SearchResponseItem(record_id=UUID(hit["_id"]), score=hit["_score"]))
# See https://www.elastic.co/guide/en/elasticsearch/reference/current/paginate-search-results.html
next_page_token = hit.get("_sort")

Expand Down Expand Up @@ -278,7 +278,7 @@ def _index_name_for_dataset(dataset: Dataset):
return f"rg.{dataset.id}"


async def get_search_engine():
async def get_search_engine() -> Generator[SearchEngine, None, None]:
config = dict(
hosts=settings.elasticsearch,
verify_certs=settings.elasticsearch_ssl_verify,
Expand Down
Loading