Skip to content

Commit

Permalink
fix: parse label for dataset settings for text and token classificati…
Browse files Browse the repository at this point in the history
…on tasks (#3497)

<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

This PR prevents errors described in #3495 by parsing labels into
strings. The backend parsing is also changed to apply label
normalization after the type conversion.

Closes #3495

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [X] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

The issue code snippet has been launched locally and verified that it
works with new changes.

**Checklist**

- [X] follows the style guidelines of this project
- [X] I did a self-review of my code
- [X] My changes generate no new warnings
- [X] I have added tests that prove my fix is effective or that my
feature works
- [X] I have added relevant notes to the CHANGELOG.md file (See
https://keepachangelog.com/)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Alvaro Bartolome <alvaro@argilla.io>
  • Loading branch information
3 people authored and keithCuniah committed Aug 3, 2023
1 parent c59c68a commit 0188cbe
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 13 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ These are the section headers that we use:

## [Unreleased]

### Fixed

- `TextClassificationSettings` and `TokenClassificationSettings` labels are properly parsed to strings both in the Python client and in the backend endpoint (Closes [#3495](https://github.com/argilla-io/argilla/issues/3495)).

### Added

- Added `PATCH /api/v1/fields/{field_id}` endpoint to update the field title and markdown settings ([#3421](https://github.com/argilla-io/argilla/pull/3421)).
Expand Down
11 changes: 10 additions & 1 deletion src/argilla/client/apis/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ class LabelsSchemaSettings(_AbstractSettings):

label_schema: Set[str]

def __post_init__(self):
if not isinstance(self.label_schema, (set, list, tuple)):
raise ValueError(
f"`label_schema` is of type={type(self.label_schema)}, but type=set is preferred, and also both type=list and type=tuple are allowed."
)
self.label_schema = set([str(label) for label in self.label_schema])

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "LabelsSchemaSettings":
label_schema = data.get("label_schema", {})
Expand Down Expand Up @@ -308,7 +315,9 @@ def _save_settings(self, dataset: _DatasetApiModel, settings: Settings):
f"The provided settings type {type(settings)} cannot be applied to dataset. Task type mismatch"
)

settings_ = self._SettingsApiModel(label_schema={"labels": [label for label in settings.label_schema]})
settings_ = self._SettingsApiModel.parse_obj(
{"label_schema": {"labels": [label for label in settings.label_schema]}}
)

try:
with api_compatibility(self, min_version="1.4"):
Expand Down
2 changes: 1 addition & 1 deletion src/argilla/server/apis/v0/models/dataset_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Schema(BaseModel):

labels: Union[List[str], List[Schema]] = Field(description="A set of labels")

@validator("labels", pre=True)
@validator("labels")
def normalize_labels(cls, labels):
"""
Labels schema accept a list of strings. Those string will be converted
Expand Down
6 changes: 5 additions & 1 deletion tests/integration/test_datasets_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
TokenClassificationSettings(label_schema={"PER", "ORG"}),
TextClassificationSettings(label_schema={"A", "B"}),
),
(
TokenClassificationSettings(label_schema=[1, 2, 3]),
TextClassificationSettings(label_schema={"A", "B"}),
),
],
)
def test_settings_workflow(
Expand All @@ -59,7 +63,7 @@ def test_settings_workflow(
datasets_api = current_api.datasets

found_settings = datasets_api.load_settings(dataset)
assert found_settings == settings_
assert {label for label in found_settings.label_schema} == {str(label) for label in settings_.label_schema}

settings_.label_schema = {"LALALA"}
configure_dataset(dataset, settings_, workspace=workspace)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Optional

import pytest
from argilla._constants import API_KEY_HEADER_NAME
Expand All @@ -37,10 +37,10 @@ async def delete_dataset(client: "AsyncClient", name: str, workspace_name: str):
assert response.status_code == 200


async def create_settings(async_client: "AsyncClient", name: str, workspace_name: str):
async def create_settings(async_client: "AsyncClient", name: str, workspace_name: str, labels: Optional[list] = None):
response = await async_client.put(
f"/api/datasets/{TaskType.text_classification}/{name}/settings",
json={"label_schema": {"labels": ["Label1", "Label2"]}},
json={"label_schema": {"labels": labels or ["Label1", "Label2"]}},
params={"workspace": workspace_name},
)
return response
Expand Down Expand Up @@ -74,16 +74,17 @@ async def log_some_data(async_client: "AsyncClient", name: str, workspace_name:
return response


@pytest.mark.parametrize("labels", [["Label1", "Label2"], ["1", "2", "3"], [1, 2, 3, 4]])
@pytest.mark.asyncio
async def test_create_dataset_settings(async_client: "AsyncClient", argilla_user: User):
async def test_create_dataset_settings(async_client: "AsyncClient", argilla_user: User, labels: list):
async_client.headers.update({API_KEY_HEADER_NAME: argilla_user.api_key})
workspace_name = argilla_user.username

name = "test_create_dataset_settings"
await delete_dataset(async_client, name, workspace_name=workspace_name)
await create_dataset(async_client, name, workspace_name=workspace_name)

response = await create_settings(async_client, name, workspace_name=workspace_name)
response = await create_settings(async_client, name, workspace_name=workspace_name, labels=["Label1", "Label2"])
assert response.status_code == 200

created = response.json()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import pytest
from argilla._constants import API_KEY_HEADER_NAME
Expand All @@ -35,10 +35,10 @@ async def delete_dataset(client: "AsyncClient", name: str, workspace_name: str):
assert response.status_code == 200


async def create_settings(async_client: "AsyncClient", name: str, workspace_name: str):
async def create_settings(async_client: "AsyncClient", name: str, workspace_name: str, labels: Optional[list] = None):
response = await async_client.put(
f"/api/datasets/{TaskType.token_classification}/{name}/settings",
json={"label_schema": {"labels": ["Label1", "Label2"]}},
json={"label_schema": {"labels": labels or ["Label1", "Label2"]}},
params={"workspace": workspace_name},
)
return response
Expand Down Expand Up @@ -81,16 +81,17 @@ async def log_some_data(async_client: "AsyncClient", name: str, workspace_name:
return response


@pytest.mark.parametrize("labels", [["Label1", "Label2"], ["1", "2", "3"], [1, 2, 3, 4]])
@pytest.mark.asyncio
async def test_create_dataset_settings(async_client: "AsyncClient", argilla_user: User):
async def test_create_dataset_settings(async_client: "AsyncClient", argilla_user: User, labels: list):
async_client.headers.update({API_KEY_HEADER_NAME: argilla_user.api_key})
workspace_name = argilla_user.username

name = "test_create_dataset_settings"
await delete_dataset(async_client, name, workspace_name=workspace_name)
await create_dataset(async_client, name, workspace_name=workspace_name)

response = await create_settings(async_client, name, workspace_name=workspace_name)
response = await create_settings(async_client, name, workspace_name=workspace_name, labels=labels)
assert response.status_code == 200

created = response.json()
Expand Down

0 comments on commit 0188cbe

Please sign in to comment.