Skip to content

Commit

Permalink
feat: Allow partial record update (#2576)
Browse files Browse the repository at this point in the history
# Description

Changes in this PR directly affect the way clients log data. If the
record info contains an `id`, the record will be partially updated.
This means that data can be logged by passing a subset of
per-record attributes instead of the whole record.

Closes #2535

Refs #2534

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Improvement (change adding some improvement to an existing
functionality)

**How Has This Been Tested**

Added tests covering this new feature

**Checklist**

- [x] I have merged the original branch into my forked branch
- [ ] I added relevant documentation
- [x] My code follows the style guidelines of this project
- [x] I did a self-review of my code
- [ ] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [x] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)
  • Loading branch information
frascuchon authored Apr 12, 2023
1 parent 1bf49ad commit 35b0916
Show file tree
Hide file tree
Showing 12 changed files with 95 additions and 153 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Argilla quickstart image dependencies are externalized into `quickstart.requirements.txt`. See [#2666](https://github.com/argilla-io/argilla/pull/2666)
- Bulk endpoints will upsert data when a record `id` is present. Closes [#2535](https://github.com/argilla-io/argilla/issues/2535)

## [1.6.0](https://github.com/argilla-io/argilla/compare/v1.5.1...v1.6.0)

Expand Down Expand Up @@ -93,6 +94,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The record inputs are fully visible when pagination size is one and the height of collapsed area size is bigger for laptop screen. [#2587](https://github.com/argilla-io/argilla/pull/2587/files)



### Fixes

- Allow URL to be clickable in Jupyter notebook again. Closes [#2527](https://github.com/argilla-io/argilla/issues/2527)
Expand Down
1 change: 1 addition & 0 deletions src/argilla/client/apis/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Datasets(AbstractApi):
_API_PREFIX = "/api/datasets"

class _DatasetApiModel(BaseModel):
id: Optional[str]
name: str
task: TaskType
owner: Optional[str] = None
Expand Down
1 change: 1 addition & 0 deletions src/argilla/client/sdk/datasets/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class BaseDatasetModel(BaseModel):


class Dataset(BaseDatasetModel):
id: str
task: TaskType
owner: str = None
workspace: str = None
Expand Down
7 changes: 4 additions & 3 deletions src/argilla/server/daos/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ def __exit__(self, exception_type, exception_value, traceback):
except bulk_error as ex:
errors = [
WrongLogDataError.Error(
reason=error_info.get("reason"),
caused_by=error_info.get("caused_by"),
reason=action_error.get("error").get("reason"),
caused_by=action_error.get("error").get("caused_by"),
)
for error in ex.errors
for error_info in [error.get("index", {}).get("error", {})]
for action_error in error.values()
if action_error.get("error")
]
raise WrongLogDataError(errors=errors)
except not_found_error as ex:
Expand Down
8 changes: 2 additions & 6 deletions src/argilla/server/daos/backend/client_adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import dataclasses
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from argilla.server.daos.backend.metrics.base import ElasticsearchMetric
from argilla.server.daos.backend.search.model import BaseQuery, SortConfig
Expand Down Expand Up @@ -179,11 +179,7 @@ def create_index(
pass

@abstractmethod
def index_documents(
self,
index: str,
docs: List[Dict[str, Any]],
) -> int:
def index_documents(self, index: str, docs: List[Dict[str, Any]]) -> int:
pass

@abstractmethod
Expand Down
30 changes: 9 additions & 21 deletions src/argilla/server/daos/backend/client_adapters/opensearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from opensearchpy import OpenSearch, helpers
from opensearchpy.exceptions import (
Expand Down Expand Up @@ -541,11 +541,7 @@ def create_index(
ignore=400,
)

def index_documents(
self,
index: str,
docs: List[Dict[str, Any]],
) -> int:
def index_documents(self, index: str, docs: List[Dict[str, Any]]) -> int:
actions = (self._doc2bulk_action(index, doc) for doc in docs)
success, failed = self.bulk(
index=index,
Expand All @@ -554,23 +550,15 @@ def index_documents(
return len(failed)

@staticmethod
def _doc2bulk_action(
index: str,
doc: Dict[str, Any],
) -> Dict[str, Any]:
def get_id(r):
return r.get("id")
def _doc2bulk_action(index: str, doc: Dict[str, Any]) -> Dict[str, Any]:
doc_id = doc.get("id")

data = {
"_op_type": "index",
"_index": index,
"_routing": None, # TODO(@frascuchon): Use a sharding routing
**doc,
}
data = (
{"_index": index, "_op_type": "index", **doc}
if doc_id is None
else {"_index": index, "_id": doc_id, "_op_type": "update", "doc_as_upsert": True, "doc": doc}
)

id = get_id(doc)
if id is not None:
data["_id"] = id
return data

def upsert_index_document(
Expand Down
12 changes: 3 additions & 9 deletions src/argilla/server/daos/backend/generic_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,13 +500,7 @@ def remove_dataset_field(self, id: str, field: str):
property=field,
)

def add_dataset_documents(
self,
id: str,
documents: List[dict],
) -> int:
def add_dataset_records(self, id: str, documents: List[dict]) -> int:
index = dataset_records_index(id)
return self.client.index_documents(
index=index,
docs=documents,
)

return self.client.index_documents(index=index, docs=documents)
12 changes: 7 additions & 5 deletions src/argilla/server/daos/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,24 @@ def add_records(
vectors_configuration = {}
for record in records:
metadata_values.update(record.metadata or {})

db_record = record_class.parse_obj(record)
db_record.last_updated = now
record_dict = db_record.dict(
exclude_none=False,
exclude=set(exclude_fields),
)

record_dict = db_record.dict(exclude_none=True, exclude=set(exclude_fields))

if record.vectors is not None:
# TODO: Create embeddings config by settings
for (
vector_name,
vector_data_mapping,
) in record.vectors.items():
vector_dimension = vectors_configuration.get(vector_name, None)

if vector_dimension is None:
dimension = len(vector_data_mapping.value)
vectors_configuration[vector_name] = dimension

documents.append(record_dict)

self._es.create_dataset(
Expand All @@ -121,7 +123,7 @@ def add_records(
vectors_cfg=vectors_configuration,
)

return self._es.add_dataset_documents(
return self._es.add_dataset_records(
id=dataset.id,
documents=documents,
)
Expand Down
17 changes: 4 additions & 13 deletions tests/client/functional_tests/test_record_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from time import sleep

import pytest
from argilla.client.api import active_api
from argilla.client.sdk.commons.errors import NotFoundApiError


def test_partial_update_with_not_found(
mocked_client,
gutenberg_spacy_ner,
):
def test_partial_update_with_not_found(mocked_client, gutenberg_spacy_ner):
with pytest.raises(NotFoundApiError):
active_api().datasets.update_record(
name=gutenberg_spacy_ner,
Expand All @@ -29,10 +27,7 @@ def test_partial_update_with_not_found(
)


def test_partial_record_update(
mocked_client,
gutenberg_spacy_ner,
):
def test_partial_record_update(mocked_client, gutenberg_spacy_ner):
expected_id = "00c27206-da48-4fc3-aab7-4b730628f8ac"

record = record_data_by_id(
Expand Down Expand Up @@ -73,11 +68,7 @@ def test_partial_record_update(
}


def record_data_by_id(
*,
dataset: str,
record_id: str,
):
def record_data_by_id(*, dataset: str, record_id: str):
data = active_api().datasets.scan(
name=dataset,
query=f"id: {record_id}",
Expand Down
30 changes: 9 additions & 21 deletions tests/client/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,39 +525,27 @@ def test_update_record(mocked_client):
mocked_client.delete(f"/api/datasets/{dataset}")

expected_inputs = ["This is a text"]
record = rg.TextClassificationRecord(
id=0,
inputs=expected_inputs,
annotation_agent="test",
annotation=["T"],
)
api.log(
record,
name=dataset,
)
record = rg.TextClassificationRecord(id=0, inputs=expected_inputs, annotation_agent="test", annotation=["T"])
api.log(record, name=dataset)

df = api.load(name=dataset)
df = df.to_pandas()
records = df.to_dict(orient="records")
assert len(records) == 1
assert records[0]["annotation"] == "T"
# This record will replace the old one
record = rg.TextClassificationRecord(
id=0,
inputs=expected_inputs,
)
# This record will be partially updated
record = rg.TextClassificationRecord(id=0, inputs=expected_inputs, metadata={"a": "value"})

api.log(
record,
name=dataset,
)
api.log(record, name=dataset)

df = api.load(name=dataset)
df = df.to_pandas()
records = df.to_dict(orient="records")

assert len(records) == 1
assert records[0]["annotation"] is None
assert records[0]["annotation_agent"] is None
assert records[0]["annotation"] == "T"
assert records[0]["annotation_agent"] == "test"
assert records[0]["metadata"] == {"a": "value"}


def test_text_classifier_with_inputs_list(mocked_client):
Expand Down
Loading

0 comments on commit 35b0916

Please sign in to comment.