
feat: add training task for qna #3740

Merged 13 commits on Sep 11, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -18,6 +18,7 @@ These are the section headers that we use:

### Added

- Added `ArgillaTrainer` integration with `TrainingTask.for_question_answering` ([#3740](https://github.com/argilla-io/argilla/pull/3740))
- Added `Auto save record` to save automatically the current record that you are working on ([#3541](https://github.com/argilla-io/argilla/pull/3541))
- Added `ArgillaTrainer` integration with OpenAI, allowing fine tuning for chat completion ([#3615](https://github.com/argilla-io/argilla/pull/3615))
- Added `workspaces list` command to list Argilla workspaces ([#3594](https://github.com/argilla-io/argilla/pull/3594)).
127 changes: 116 additions & 11 deletions docs/_source/guides/llms/practical_guides/fine_tune.md
@@ -42,6 +42,7 @@ We plan on adding more support for other tasks and frameworks so feel free to re
| Task/Framework | TRL | OpenAI | SetFit | spaCy | Transformers | PEFT |
|:--------------------------------|:-----|:-------|:-------|:------|:-------------|:-----|
| Text Classification | | | ✔️ | ✔️ | ✔️ | ✔️ |
| Question Answering | | | | | ✔️ | |
| Supervised Fine-tuning | ✔️ | | | | | |
| Reward Modeling | ✔️ | | | | | |
| Proximal Policy Optimization | ✔️ | | | | | |
@@ -70,11 +71,12 @@ A `TrainingTask` is used to define how the data should be processed and formatte
| Method | Content | `formatting_func` return type | Default|
|:-----------------------------------|:-----------------------------|:---------------------------------------------------------------------------------|:-------|
| for_text_classification | `text-label` | `Union[Tuple[str, str], Tuple[str, List[str]]]` | ✔️ |
| for_question_answering             | `question-context-answer`    | `Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]]`                      | ✔️ |
| for_supervised_fine_tuning         | `text`                       | `Union[str, Iterator[str]]`                                                        | ✗ |
| for_reward_modeling                | `chosen-rejected`            | `Union[Tuple[str, str], Iterator[Tuple[str, str]]]`                                | ✗ |
| for_proximal_policy_optimization   | `text`                       | `Union[str, Iterator[str]]`                                                        | ✗ |
| for_direct_preference_optimization | `prompt-chosen-rejected`     | `Union[Tuple[str, str, str], Iterator[Tuple[str, str, str]]]`                      | ✗ |
| for_chat_completion                | `chat-turn-role-content`     | `Union[Tuple[str, str, str, str], Iterator[Tuple[str, str, str, str]]]`            | ✗ |


## Tasks
@@ -111,7 +113,7 @@ For a multi-label scenario it is recommended to add some examples without any la

::::

We then use either the `text-label` pair or a `formatting_func` to further fine-tune the model.

#### Training

@@ -175,9 +177,9 @@ def formatting_func(sample):
element for element, frequency in most_common if frequency == max_frequency
]
label = random.choice(most_common_elements)
yield (text, label)
else:
yield None

task = TrainingTask.for_text_classification(formatting_func=formatting_func)
```
@@ -202,6 +204,109 @@ trainer = ArgillaTrainer(
trainer.train(output_dir="textcat_model")
```

### Question Answering

#### Background

The extractive Question Answering (QnA) task involves answering questions posed by users based on a given context. It is challenging because the model must comprehend both the question and the context in which it is asked, grasp the relationship between the two, and extract from the context an answer that is both accurate and relevant to the question.

Below you can find a sample of an extractive QnA dataset:

```python
{
'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
'answers': 'Saint Bernadette Soubirous',
}
```

```{note}
Officially, answers need to be passed as a list of `{'answer_start': int, 'text': str}` dicts. However, we only support passing a string, and the `answer_start` is inferred from the `context` and `text` fields.
```
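
Since only the answer string is passed, the offset has to be recovered by locating that string in the context. The helper below is a minimal sketch of that inference, not Argilla's actual implementation; the function name is hypothetical:

```python
from typing import Dict, List

def to_squad_answers(context: str, text: str) -> List[Dict]:
    """Build a SQuAD-style answers list by locating `text` inside `context`."""
    start = context.find(text)
    if start == -1:
        return []  # the answer does not occur verbatim in the context
    return [{"answer_start": start, "text": text}]

# For the sample above, this returns
# [{'answer_start': <offset of the span>, 'text': 'Saint Bernadette Soubirous'}]
```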

We then use either the `question-context-answer` set or a `formatting_func` to further fine-tune the model.

#### Training

**Data Preparation**

```python
import argilla as rg
from datasets import Dataset

feedback_dataset = rg.FeedbackDataset.from_huggingface("argilla/squad")
```

We can use a default configuration where we initialize `TrainingTask.for_question_answering` using the `question-context-answer` set from the dataset. We also offer the option to provide a `formatting_func` to `TrainingTask.for_question_answering`. This function is applied to each sample in the dataset and can be used for advanced preprocessing and data formatting. The function should return a `question-context-answer` set as `str-str-str`.

:::: {tab-set}

::: {tab-item} question-context-answer-set

```python
from argilla.feedback import TrainingTask

task = TrainingTask.for_question_answering(
question=feedback_dataset.field_by_name("question"),
context=feedback_dataset.field_by_name("context"),
answer=feedback_dataset.question_by_name("answer"),
)
```

:::

::: {tab-item} formatting_func

```python
from argilla.feedback import TrainingTask

def formatting_func(sample):
question = sample["question"]
context = sample["context"]
for answer in sample["answer"]:
if not all([question, context, answer["value"]]):
continue
yield question, context, answer["value"]

task = TrainingTask.for_question_answering(formatting_func=formatting_func)
```

:::

::::

**ArgillaTrainer**

Next, we can define our `ArgillaTrainer` for any of [the supported frameworks](#supported-frameworks) and [customize the training config](fine_tune.md#training-configs) using `ArgillaTrainer.update_config`.

```python
from argilla.feedback import ArgillaTrainer

trainer = ArgillaTrainer(
dataset=feedback_dataset,
task=task,
framework="transformers",
train_size=0.8,
)

trainer.train(output_dir="qna_model")

```
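
As a hedged sketch of such customization: with the `transformers` framework, keyword arguments passed to `update_config` are forwarded to the underlying `transformers.TrainingArguments`, so typical hyperparameters can be tweaked before training (the values below are illustrative assumptions, not recommendations):

```python
# Illustrative hyperparameters; forwarded to transformers.TrainingArguments.
trainer.update_config(
    learning_rate=3e-5,
    num_train_epochs=3,
    per_device_train_batch_size=16,
)
```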

**Inference**

Lastly, this model can be used for inference using the `pipeline` method from the [Transformers library](https://huggingface.co/tasks/question-answering). We can use the `question-answering` pipeline for this task.

```python
from transformers import pipeline

qa_model = pipeline("question-answering", model="qna_model")
question = "Where do I live?"
context = "My name is Merve and I live in İstanbul."
qa_model(question=question, context=context)
# {'answer': 'İstanbul', 'end': 39, 'score': 0.953, 'start': 31}
```

### Supervised finetuning

#### Background
@@ -294,7 +399,7 @@ template = """\

def formatting_func(sample: Dict[str, Any]) -> str:
# What `sample` looks like depends a lot on your FeedbackDataset fields and questions
yield template.format(
instruction=sample["new-instruction"][0]["value"],
context=sample["new-context"][0]["value"],
response=sample["new-response"][0]["value"],
@@ -910,13 +1015,13 @@ def formatting_func(sample: dict) -> Union[Tuple[str, str, str, str], List[Tuple
if sample["response"]:
chat = str(uuid4())
user_message = user_message_prompt.format(context_str=sample["context"], query_str=sample["user-message"])
yield [
(chat, "0", "system", system_prompt),
(chat, "1", "user", user_message),
(chat, "2", "assistant", sample["response"][0]["value"])
]
else:
yield None

task = TrainingTask.for_chat_completion(formatting_func=formatting_func)
```
51 changes: 4 additions & 47 deletions src/argilla/client/feedback/dataset/base.py
@@ -13,7 +13,6 @@
# limitations under the License.

import logging
import warnings
from abc import ABC, abstractproperty
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union

@@ -23,27 +22,18 @@
from argilla.client.feedback.schemas import (
FeedbackRecord,
FieldSchema,
LabelQuestion,
MultiLabelQuestion,
RankingQuestion,
RatingQuestion,
)
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes
from argilla.client.feedback.training.schemas import (
TrainingTaskForChatCompletion,
TrainingTaskForDPO,
TrainingTaskForPPO,
TrainingTaskForQuestionAnswering,
TrainingTaskForRM,
TrainingTaskForSFT,
TrainingTaskForTextClassification,
TrainingTaskTypes,
)
from argilla.client.feedback.unification import (
LabelQuestionStrategy,
MultiLabelQuestionStrategy,
RankingQuestionStrategy,
RatingQuestionStrategy,
)
from argilla.client.feedback.utils import generate_pydantic_schema
from argilla.client.models import Framework
from argilla.utils.dependency import require_version, requires_version
@@ -272,42 +262,6 @@ def format_as(self, format: Literal["datasets"]) -> "Dataset":
return self._huggingface_format(self)
raise ValueError(f"Unsupported format '{format}'.")

# TODO(davidberenstein1957): detatch unification into a mixin
def unify_responses(
self,
question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion],
strategy: Union[
str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy
],
) -> None:
"""
The `unify_responses` function takes a question and a strategy as input and applies the strategy
to unify the responses for that question.

Args:
question The `question` parameter can be either a string representing the name of the
question, or an instance of one of the question classes (`LabelQuestion`, `MultiLabelQuestion`,
`RatingQuestion`, `RankingQuestion`).
strategy The `strategy` parameter is used to specify the strategy to be used for unifying
responses for a given question. It can be either a string or an instance of a strategy class.
"""
if isinstance(question, str):
question = self.question_by_name(question)

if isinstance(strategy, str):
if isinstance(question, LabelQuestion):
strategy = LabelQuestionStrategy(strategy)
elif isinstance(question, MultiLabelQuestion):
strategy = MultiLabelQuestionStrategy(strategy)
elif isinstance(question, RatingQuestion):
strategy = RatingQuestionStrategy(strategy)
elif isinstance(question, RankingQuestion):
strategy = RankingQuestionStrategy(strategy)
else:
raise ValueError(f"Question {question} is not supported yet")

strategy.unify_responses(self.records, question)

# TODO(alvarobartt,davidberenstein1957): we should consider having something like
# `export(..., training=True)` to export the dataset records in any format, replacing
# both `format_as` and `prepare_for_training`
@@ -361,6 +315,9 @@ def prepare_for_training(
if isinstance(task, TrainingTaskForTextClassification):
if task.formatting_func is None:
self.unify_responses(question=task.label.question, strategy=task.label.strategy)
elif isinstance(task, TrainingTaskForQuestionAnswering):
if task.formatting_func is None:
self.unify_responses(question=task.answer.name, strategy="disagreement")
elif not isinstance(
task,
(
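
As a quick usage sketch of the branch added above: when a `TrainingTaskForQuestionAnswering` has no `formatting_func`, `prepare_for_training` now unifies the `answer` responses with the `disagreement` strategy before exporting. The variables below refer to the objects defined in the docs section earlier:

```python
# `feedback_dataset` and `task` as defined in the question-answering docs above.
training_data = feedback_dataset.prepare_for_training(
    framework="transformers",
    task=task,
)
```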
4 changes: 2 additions & 2 deletions src/argilla/client/feedback/dataset/local.py
@@ -16,14 +16,14 @@

from argilla.client.feedback.constants import FETCHING_BATCH_SIZE
from argilla.client.feedback.dataset.base import FeedbackDatasetBase
from argilla.client.feedback.dataset.mixins import ArgillaMixin, UnificationMixin
from argilla.client.feedback.schemas.types import AllowedFieldTypes, AllowedQuestionTypes

if TYPE_CHECKING:
from argilla.client.feedback.schemas import FeedbackRecord


class FeedbackDataset(FeedbackDatasetBase, ArgillaMixin, UnificationMixin):
def __init__(
self,
*,
46 changes: 46 additions & 0 deletions src/argilla/client/feedback/dataset/mixins.py
@@ -29,6 +29,13 @@
TextQuestion,
)
from argilla.client.feedback.schemas.types import AllowedQuestionTypes
from argilla.client.feedback.unification import (
LabelQuestionStrategy,
MultiLabelQuestionStrategy,
RankingQuestionStrategy,
RatingQuestionStrategy,
TextQuestionStrategy,
)
from argilla.client.feedback.utils import feedback_dataset_in_argilla
from argilla.client.sdk.v1.datasets import api as datasets_api_v1
from argilla.client.workspaces import Workspace
@@ -339,3 +346,42 @@ def list(cls: Type["FeedbackDataset"], workspace: Optional[str] = None) -> List[
)
for dataset in datasets
]


class UnificationMixin:
def unify_responses(
self,
question: Union[str, LabelQuestion, MultiLabelQuestion, RatingQuestion],
strategy: Union[
str, LabelQuestionStrategy, MultiLabelQuestionStrategy, RatingQuestionStrategy, RankingQuestionStrategy
],
) -> None:
"""
The `unify_responses` function takes a question and a strategy as input and applies the strategy
to unify the responses for that question.

Args:
question: Either a string representing the name of the question, or an instance of one of
the question classes (`LabelQuestion`, `MultiLabelQuestion`, `RatingQuestion`,
`RankingQuestion`, `TextQuestion`).
strategy: The strategy to be used for unifying responses for the given question. It can
be either a string or an instance of a strategy class.
"""
if isinstance(question, str):
question = self.question_by_name(question)

if isinstance(strategy, str):
if isinstance(question, LabelQuestion):
strategy = LabelQuestionStrategy(strategy)
elif isinstance(question, MultiLabelQuestion):
strategy = MultiLabelQuestionStrategy(strategy)
elif isinstance(question, RatingQuestion):
strategy = RatingQuestionStrategy(strategy)
elif isinstance(question, RankingQuestion):
strategy = RankingQuestionStrategy(strategy)
elif isinstance(question, TextQuestion):
strategy = TextQuestionStrategy(strategy)
else:
raise ValueError(f"Question {question} is not supported yet")

strategy.unify_responses(self.records, question)
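
For context, a short usage sketch of the relocated method; the question name and strategy mirror the question-answering defaults wired up in `prepare_for_training` above (the dataset variable is hypothetical):

```python
# Collapse multiple annotator responses for a TextQuestion into one value,
# the same unification prepare_for_training performs for question answering.
feedback_dataset.unify_responses(question="answer", strategy="disagreement")
```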
@@ -16,7 +16,7 @@
from datasets import Dataset, DatasetDict

from argilla.client.feedback.training.base import ArgillaTrainerSkeleton
from argilla.client.feedback.training.schemas import TrainingTaskForQuestionAnswering, TrainingTaskForTextClassification
from argilla.training.transformers import ArgillaTransformersTrainer as ArgillaTransformersTrainerV1


@@ -27,6 +27,7 @@ def __init__(self, *args, **kwargs):

import torch
from transformers import (
AutoModelForQuestionAnswering,
AutoModelForSequenceClassification,
set_seed,
)
@@ -59,6 +60,8 @@

if isinstance(self._task, TrainingTaskForTextClassification):
self._model_class = AutoModelForSequenceClassification
elif isinstance(self._task, TrainingTaskForQuestionAnswering):
self._model_class = AutoModelForQuestionAnswering
else:
raise NotImplementedError(
f"ArgillaTransformersTrainer does not support {self._task.__class__.__name__} yet."