Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add metadata to Record #3194

Merged
merged 29 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
8a0e370
feat: add `metadata` column to `Record` table
gabrielmbmb Jun 14, 2023
f71bb41
feat: return `metadata` from listing records endoints
gabrielmbmb Jun 14, 2023
c876839
feat: update listing records unit tests
gabrielmbmb Jun 14, 2023
c4ebbc4
feat: add `metadata` attribute to `FeedbackRecord`
gabrielmbmb Jun 14, 2023
a5e0703
feat: add `metadata` attribute to `FeedbackItemModel`
gabrielmbmb Jun 14, 2023
b1ed498
feat: add `metadata` to unit tests
gabrielmbmb Jun 14, 2023
891f0a8
feat: add `metadata` to HF dataset
gabrielmbmb Jun 14, 2023
4eb187b
fix: `test_get_records` dataset was not being published
gabrielmbmb Jun 14, 2023
f63346a
Merge branch 'develop' into feature/api-record-metadata-column
gabrielmbmb Jun 14, 2023
64a922f
feat: update `from_huggingface` to include metadata
gabrielmbmb Jun 14, 2023
9501f23
feat: update `feedback_dataset_in_argilla` return
gabrielmbmb Jun 14, 2023
ef1960b
refactor: update if conditions
gabrielmbmb Jun 14, 2023
91774d3
feat: split string in two lines
gabrielmbmb Jun 14, 2023
5c90fa6
feat: update push/from methods tests after adding metadata
gabrielmbmb Jun 14, 2023
4e2d8a2
docs: mention `metadata` attribute
gabrielmbmb Jun 15, 2023
023d748
docs: add `metadata` attribute
gabrielmbmb Jun 15, 2023
819da07
fix: wrong assertion
gabrielmbmb Jun 15, 2023
0a2fd4d
feat: improve exception traceback
gabrielmbmb Jun 15, 2023
af1f75f
fix: remove additional whitespace
gabrielmbmb Jun 15, 2023
1f8562f
feat: remove `1.10.0` revision (will be included `1.11.0`)
gabrielmbmb Jun 15, 2023
1370614
Merge branch 'feature/api-record-metadata-column' of https://github.c…
gabrielmbmb Jun 15, 2023
cb4ae71
docs: move `metadata` to `1.11.0`
gabrielmbmb Jun 15, 2023
c8d4b7f
feat: remove `RecordGetterDict` class
gabrielmbmb Jun 15, 2023
c4399ea
feat: update `delete_and_raise_exception` to `delete_dataset`
gabrielmbmb Jun 15, 2023
af8a4dd
feat: add `RecordGetterDict` class and remove custom `from_orm` methods
gabrielmbmb Jun 16, 2023
367fb64
fix: `metadata` should be `metadata_`
gabrielmbmb Jun 16, 2023
66aa818
Merge branch 'develop' into feature/api-record-metadata-column
gabrielmbmb Jun 16, 2023
9833fd2
fix: default value should be `None`
gabrielmbmb Jun 16, 2023
a437ddf
docs: remove `metadata` from `1.10.0`
gabrielmbmb Jun 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ These are the section headers that we use:

## [1.10.0](https://github.com/argilla-io/argilla/compare/v1.9.0...v1.10.0)

### Added

- Added `metadata` attribute to the `Record` of the `FeedbackDataset` ([#3194](https://github.com/argilla-io/argilla/pull/3194))

### Changed

- Updated `SearchEngine` and `POST /api/v1/me/datasets/{dataset_id}/records/search` to return the `total` number of records matching the search query ([#3166](https://github.com/argilla-io/argilla/pull/3166))
Expand Down
2 changes: 2 additions & 0 deletions docs/_source/guides/llms/practical_guides/create_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ Take some time to inspect the data before adding it to the dataset in case this
The next step is to create records following Argilla's `FeedbackRecord` format. These are the attributes of a `FeedbackRecord`:

- `fields`: A dictionary with the name (key) and content (value) of each of the fields in the record. These will need to match the fields set up in the dataset configuration (see [Define record fields](#define-record-fields)).
- `metadata` (optional): A dictionary with the metadata of the record. This can include any information about the record that is not part of the fields. For example, the source of the record or the date it was created. If there is no metadata, this will be `None`.
- `external_id` (optional): An ID of the record defined by the user. If there is no external ID, this will be `None`.
- `responses` (optional): A list of all responses to a record. There is no need to configure this when creating a record, it will be filled automatically with the responses collected from the Argilla UI.

Expand All @@ -144,6 +145,7 @@ record = rg.FeedbackRecord(
"question": "Why can camels survive long without water?",
"answer": "Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time."
},
metadata={"source": "encyclopedia"},
external_id=None
)
```
Expand Down
40 changes: 25 additions & 15 deletions src/argilla/client/feedback/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def push_to_argilla(self, name: Optional[str] = None, workspace: Optional[Union[
"""
httpx_client: "httpx.Client" = rg.active_client().http_client.httpx

if not name or (not name and not workspace):
if name is None:
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
if self.argilla_id is None:
_LOGGER.warning(
"No `name` or `workspace` have been provided, and no dataset has"
Expand All @@ -515,19 +515,19 @@ def push_to_argilla(self, name: Optional[str] = None, workspace: Optional[Union[
self.__records += self.__new_records
self.__new_records = []
except Exception as e:
Exception(
raise Exception(
"Failed while adding new records to the current `FeedbackTask`"
f" dataset in Argilla with exception: {e}"
)
elif name or (name and workspace):
else:
if workspace is None:
workspace = rg.Workspace.from_name(rg.active_client().get_workspace())

if isinstance(workspace, str):
workspace = rg.Workspace.from_name(workspace)

dataset_exists, _ = feedback_dataset_in_argilla(name=name, workspace=workspace)
if dataset_exists:
dataset = feedback_dataset_in_argilla(name=name, workspace=workspace)
if dataset is not None:
raise RuntimeError(
f"Dataset with name=`{name}` and workspace=`{workspace.name}`"
" already exists in Argilla, please choose another name and/or"
Expand Down Expand Up @@ -611,12 +611,10 @@ def delete_and_raise_exception(dataset_id: UUID, exception: Exception) -> None:

if self.argilla_id is not None:
_LOGGER.warning(
"Since the current object is already a `FeedbackDataset` pushed to"
" Argilla, you'll keep on interacting with the same dataset in"
" Argilla, even though the one you just pushed holds a different"
f" ID ({argilla_id}). So on, if you want to switch to the newly"
" pushed `FeedbackDataset` instead, please use"
f" `FeedbackDataset.from_argilla(id='{argilla_id}')`."
"Since the current object is already a `FeedbackDataset` pushed to Argilla, you'll keep on"
" interacting with the same dataset in Argilla, even though the one you just pushed holds a"
f" different ID ({argilla_id}). So on, if you want to switch to the newly pushed `FeedbackDataset`"
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
f" instead, please use `FeedbackDataset.from_argilla(id='{argilla_id}')`."
)
return
self.argilla_id = argilla_id
Expand Down Expand Up @@ -656,8 +654,8 @@ def from_argilla(
"""
httpx_client: "httpx.Client" = rg.active_client().http_client.httpx

dataset_exists, existing_dataset = feedback_dataset_in_argilla(name=name, workspace=workspace, id=id)
if not dataset_exists:
existing_dataset = feedback_dataset_in_argilla(name=name, workspace=workspace, id=id)
if existing_dataset is None:
raise ValueError(
f"Could not find a `FeedbackTask` dataset in Argilla with name='{name}'."
if name and not workspace
Expand Down Expand Up @@ -729,7 +727,7 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset":
if format == "datasets":
from datasets import Dataset, Features, Sequence, Value

dataset = {}
dataset = {"metadata": []}
features = {}
for field in self.fields:
if field.settings["type"] not in FIELD_TYPE_TO_PYTHON_TYPE.keys():
Expand Down Expand Up @@ -782,8 +780,14 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset":
]
or None
)
dataset["metadata"].append(json.dumps(record.metadata) if record.metadata else None)
dataset["external_id"].append(record.external_id or None)

if len(dataset["metadata"]) > 0:
features["metadata"] = Value(dtype="string")
else:
del dataset["metadata"]

return Dataset.from_dict(
dataset,
features=Features(features),
Expand Down Expand Up @@ -860,7 +864,7 @@ def push_to_huggingface(self, repo_id: str, generate_card: Optional[bool] = True
@classmethod
@requires_version("datasets")
@requires_version("huggingface_hub")
def from_huggingface(cls, repo_id: str, *args, **kwargs) -> "FeedbackDataset":
def from_huggingface(cls, repo_id: str, *args: Any, **kwargs: Any) -> "FeedbackDataset":
"""Loads a `FeedbackDataset` from the HuggingFace Hub.

Args:
Expand Down Expand Up @@ -936,9 +940,15 @@ def from_huggingface(cls, repo_id: str, *args, **kwargs) -> "FeedbackDataset":
"values": {},
}
responses[user_id]["values"].update({question.name: {"value": value}})

metadata = None
if "metadata" in hfds[index] and hfds[index]["metadata"] is not None:
metadata = json.loads(hfds[index]["metadata"])

cls.__records.append(
FeedbackRecord(
fields={field.name: hfds[index][field.name] for field in cls.fields},
metadata=metadata,
responses=list(responses.values()) or None,
external_id=hfds[index]["external_id"],
)
Expand Down
4 changes: 4 additions & 0 deletions src/argilla/client/feedback/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,30 @@ class FeedbackRecord(BaseModel):

Args:
fields (Dict[str, str]): The fields of the record.
metadata (Optional[Dict[str, Any]]): The metadata of the record. Defaults to None.
responses (Optional[Union[ResponseSchema, List[ResponseSchema]]]): The responses of the record. Defaults to None.
external_id (Optional[str]): The external id of the record. Defaults to None.

Examples:
>>> import argilla as rg
>>> rg.FeedbackRecord(
... fields={"text": "This is the first record", "label": "positive"},
... metadata={"first": True, "nested": {"more": "stuff"}},
... responses=[{"values": {"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}}],
... external_id="entry-1",
... )
>>> # or use a ResponseSchema directly
>>> rg.FeedbackRecord(
... fields={"text": "This is the first record", "label": "positive"},
... metadata={"first": True, "nested": {"more": "stuff"}},
... responses=[rg.ResponseSchema(values={"question-1": {"value": "This is the first answer"}, "question-2": {"value": 5}}))],
... external_id="entry-1",
... )

"""

fields: Dict[str, str]
metadata: Optional[Dict[str, Any]] = None
responses: Optional[Union[ResponseSchema, List[ResponseSchema]]] = None
external_id: Optional[str] = None

Expand Down
35 changes: 14 additions & 21 deletions src/argilla/client/feedback/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def feedback_dataset_in_argilla(
*,
workspace: Optional[Union[str, rg.Workspace]] = None,
id: Optional[str] = None,
) -> Tuple[bool, Optional["FeedbackDatasetModel"]]:
) -> Union["FeedbackDatasetModel", None]:
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
"""Checks whether a `FeedbackDataset` exists in Argilla or not, based on the `name`, `id`, or the combination of
`name` and `workspace`.

Expand All @@ -88,8 +88,7 @@ def feedback_dataset_in_argilla(
id: the Argilla ID of the `FeedbackDataset`.

Returns:
A tuple with a boolean indicating whether the `FeedbackDataset` exists in Argilla or not, and the `FeedbackDatasetModel`
object if the `FeedbackDataset` exists, `None` otherwise.
The `FeedbackDataset` if it exists in Argilla, `None` otherwise.

Raises:
ValueError: if the `workspace` is not a `rg.Workspace` instance or a string.
Expand All @@ -99,24 +98,16 @@ def feedback_dataset_in_argilla(
>>> import argilla as rg
>>> rg.init(api_url="...", api_key="...")
>>> from argilla.client.feedback.dataset import feedback_dataset_in_argilla
>>> fds_exists, fds_cls = feedback_dataset_in_argilla(name="my-dataset")
>>> dataset = feedback_dataset_in_argilla(name="my-dataset")
"""
assert (name and workspace) or name or id, (
"You must provide either the `name` and `workspace` (the latter just if"
" applicable, if not the default `workspace` will be used) or the `id`, which"
" is the Argilla ID of the `rg.FeedbackDataset`."
)

httpx_client: "httpx.Client" = rg.active_client().http_client.httpx

if (name and workspace) or name:
if name:
if workspace is None:
workspace = rg.Workspace.from_name(rg.active_client().get_workspace())

if isinstance(workspace, str):
elif isinstance(workspace, str):
workspace = rg.Workspace.from_name(workspace)

if not isinstance(workspace, rg.Workspace):
elif not isinstance(workspace, rg.Workspace):
raise ValueError(f"Workspace must be a `rg.Workspace` instance or a string, got {type(workspace)}")

try:
Expand All @@ -126,10 +117,12 @@ def feedback_dataset_in_argilla(

for dataset in datasets:
if dataset.name == name and dataset.workspace_id == workspace.id:
return True, dataset
return False, None
else:
return dataset
return None
elif id:
try:
return True, datasets_api_v1.get_dataset(client=httpx_client, id=id).parsed
except:
return False, None
return datasets_api_v1.get_dataset(client=httpx_client, id=id).parsed
except Exception:
return None
else:
raise ValueError("You must provide either the `name` and `workspace` or the `id` of the `FeedbackDataset`.")
3 changes: 2 additions & 1 deletion src/argilla/client/sdk/v1/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,8 @@ def add_records(
if response["user_id"] is None:
if response_without_user_id:
warnings.warn(
f"Multiple responses without `user_id` found in record {record}, so just the first one will be used while the rest will be ignored."
f"Multiple responses without `user_id` found in record {record}, so just the first one will be"
" used while the rest will be ignored."
)
continue
else:
Expand Down
1 change: 1 addition & 0 deletions src/argilla/client/sdk/v1/datasets/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class FeedbackResponseModel(BaseModel):
class FeedbackItemModel(BaseModel):
id: UUID
fields: Dict[str, Any]
metadata: Optional[Dict[str, Any]] = None
external_id: Optional[str] = None
responses: Optional[List[FeedbackResponseModel]] = []
inserted_at: datetime
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""add record metadata column

Revision ID: 3ff6484f8b37
Revises: ae5522b4c674
Create Date: 2023-06-14 13:02:41.735153

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "3ff6484f8b37"
down_revision = "ae5522b4c674"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("records", sa.Column("metadata", sa.JSON(), nullable=True))
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("records", "metadata")
# ### end Alembic commands ###
7 changes: 4 additions & 3 deletions src/argilla/server/apis/v1/handlers/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
Question,
QuestionCreate,
Questions,
Record,
RecordInclude,
Records,
RecordsCreate,
Expand Down Expand Up @@ -127,7 +128,7 @@ def list_current_user_dataset_records(
db, dataset_id, current_user.id, include=include, response_status=response_status, offset=offset, limit=limit
)

return Records(items=[record.__dict__ for record in records])
return Records.from_orm(records)


@router.get("/datasets/{dataset_id}/records", response_model=Records, response_model_exclude_unset=True)
Expand All @@ -146,7 +147,7 @@ def list_dataset_records(

records = datasets.list_records_by_dataset_id(db, dataset_id, include=include, offset=offset, limit=limit)

return Records(items=[record.__dict__ for record in records])
return Records.from_orm(records)


@router.get("/datasets/{dataset_id}", response_model=Dataset)
Expand Down Expand Up @@ -349,7 +350,7 @@ async def search_dataset_records(

for record in records:
record_id_score_map[record.id]["search_record"] = SearchRecord(
record=record.__dict__, query_score=record_id_score_map[record.id]["query_score"]
record=Record.from_orm(record), query_score=record_id_score_map[record.id]["query_score"]
)

return SearchRecordsResult(
Expand Down
1 change: 1 addition & 0 deletions src/argilla/server/contexts/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ async def create_records(

record = Record(
fields=record_create.fields,
metadata_=record_create.metadata,
external_id=record_create.external_id,
dataset_id=dataset.id,
)
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from sqlalchemy.orm import Session

ALEMBIC_CONFIG_FILE = os.path.normpath(os.path.join(os.path.dirname(argilla.__file__), "alembic.ini"))
TAGGED_REVISIONS = OrderedDict({"1.7": "1769ee58fbb4", "1.8": "ae5522b4c674"})
TAGGED_REVISIONS = OrderedDict({"1.7": "1769ee58fbb4", "1.8": "ae5522b4c674", "1.10": "3ff6484f8b37"})
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved


@event.listens_for(Engine, "connect")
Expand Down
1 change: 1 addition & 0 deletions src/argilla/server/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class Record(Base):

id: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)
fields: Mapped[dict] = mapped_column(JSON, default={})
metadata_: Mapped[Optional[dict]] = mapped_column("metadata", JSON, nullable=True)
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
external_id: Mapped[Optional[str]] = mapped_column(index=True)
dataset_id: Mapped[UUID] = mapped_column(ForeignKey("datasets.id", ondelete="CASCADE"), index=True)

Expand Down
Loading