
Bug in DatasetForTextClassification._prepare_for_training_with_transformers when multi_label=True #2606

Closed
alvarobartt opened this issue Mar 27, 2023 · 1 comment

@alvarobartt
Member

Describe the bug

When calling .prepare_for_training() with the default arguments on a DatasetForTextClassification dataset whose records have multi_label=True, the _prepare_for_training_with_transformers function fails with a KeyError. This is due to the recent addition of the context key when building the 🤗 Dataset.

IMO, the following line should either be removed or made optional, so that the context key is only used when it is actually available.

https://github.com/argilla-io/argilla/blame/ba4ae63dc43ae1d36aa0efcf03896bcf5d9206ee/src/argilla/client/datasets.py#L764
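For illustration, a minimal sketch of the "make it optional" approach (the names ds, ds_dict, labels and binarized_labels are taken from the traceback below; the exact surrounding code is an assumption, not the actual implementation):

def build_columns(ds, ds_dict, labels, binarized_labels):
    # Sketch only, not Argilla's real code: build the column mapping for the 🤗 Dataset.
    columns = {
        "id": ds["id"],
        "text": ds["text"],
        "label": labels,
        "binarized_label": binarized_labels,
    }
    # Records created before 1.5.0 have no "context" field, so don't assume it exists.
    if "context" in ds_dict:
        columns["context"] = ds_dict["context"]
    return columns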

Also, the bug was introduced in 1.5.0: the same code works fine in 1.4.0, and the CI still seems to be passing despite this regression.

To Reproduce

Run the following script to reproduce:

import argilla as rg

rb_dataset = rg.DatasetForTextClassification([
    rg.TextClassificationRecord(
        text="This is a sample text",
        annotation="pos",
        multi_label=True,
    )
])
rb_dataset.prepare_for_training()  # raises KeyError: 'context' on 1.5.0

Expected behavior

The context key shouldn't be mandatory, to keep consistency with datasets created before 1.5.0. With that change, the code above should work and the CI should pass.
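For completeness, a rough sketch of a regression test that could accompany the fix (hypothetical test, not part of the current suite; it assumes prepare_for_training() returns a 🤗 Dataset exposing column_names, as suggested by the dict built in datasets.py above):

import argilla as rg


def test_prepare_for_training_multi_label_without_context():
    dataset = rg.DatasetForTextClassification(
        [
            rg.TextClassificationRecord(
                text="This is a sample text",
                annotation="pos",
                multi_label=True,
            )
        ]
    )
    # Should not raise KeyError: 'context'
    prepared = dataset.prepare_for_training()
    assert "text" in prepared.column_names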

Environment (please complete the following information):

  • OS: macOS
  • Browser: N/A
  • Argilla Version: 1.5.0
  • ElasticSearch Version: N/A
  • Docker Image (optional): N/A
@tomaarsen
Contributor

Thank you for reporting this!

See the stack trace here:
 & C:/Users/tom/.conda/envs/argilla/python.exe c:/code/argilla/issue_2606.py
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ c:/code/argilla/issue_2606.py:10 in <module>                                                     │
│                                                                                                  │
│    7 │   │   multi_label=True,                                                                   │
│    8 │   )                                                                                       │
│    9 ])                                                                                          │
│ ❱ 10 rb_dataset.prepare_for_training()                                                           │
│   11                                                                                             │
│                                                                                                  │
│ C:\code\argilla\src\argilla\client\datasets.py:431 in prepare_for_training                       │
│                                                                                                  │
│    428 │   │                                                                                     │
│    429 │   │   # prepare for training for the right method                                       │
│    430 │   │   if framework is Framework.TRANSFORMERS:                                           │
│ ❱  431 │   │   │   return self._prepare_for_training_with_transformers(train_size=train_size, t  │
│    432 │   │   elif framework is Framework.SPACY and lang is None:                               │
│    433 │   │   │   raise ValueError(                                                             │
│    434 │   │   │   │   "Please provide a spacy language model to prepare the" " dataset for tra  │
│                                                                                                  │
│ C:\code\argilla\src\argilla\utils\dependency.py:128 in check_if_installed                        │
│                                                                                                  │
│   125 │   @functools.wraps(func)                                                                 │
│   126 │   def check_if_installed(*args, **kwargs):                                               │
│   127 │   │   require_version(requirement, func.__name__)                                        │
│ ❱ 128 │   │   return func(*args, **kwargs)                                                       │
│   129 │                                                                                          │
│   130 │   return check_if_installed                                                              │
│   131                                                                                            │
│                                                                                                  │
│ C:\code\argilla\src\argilla\client\datasets.py:764 in _prepare_for_training_with_transformers    │
│                                                                                                  │
│    761 │   │   │   │   {                                                                         │
│    762 │   │   │   │   │   "id": ds["id"],                                                       │
│    763 │   │   │   │   │   "text": ds["text"],                                                   │
│ ❱  764 │   │   │   │   │   "context": ds_dict["context"],                                        │
│    765 │   │   │   │   │   "label": labels,                                                      │
│    766 │   │   │   │   │   "binarized_label": binarized_labels,                                  │
│    767 │   │   │   │   },                                                                        │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'context'
