Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Multi-label text classification (#401)
Browse files Browse the repository at this point in the history
* Add toxic comments example

* Updates

* Clean

* Add docs

* Add docs

* Update CHANGELOG.md

* Fix test

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
ethanwharris and mergify[bot] authored Jun 11, 2021
1 parent bec7142 commit d23319d
Show file tree
Hide file tree
Showing 13 changed files with 232 additions and 71 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,5 @@ kinetics
movie_posters
CameraRGB
CameraSeg
jigsaw_toxic_comments
flash_examples/serve/tabular_classification/data
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
- Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))

### Changed

Expand Down
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ Lightning Flash

reference/task
reference/image_classification
reference/image_classification_multi_label
reference/image_embedder
reference/multi_label_classification
reference/summarization
reference/text_classification
reference/text_classification_multi_label
reference/tabular_classification
reference/translation
reference/object_detection
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

.. _multi_label_classification:
.. _image_classification_multi_label:

################################
Multi-label Image Classification
Expand Down
57 changes: 57 additions & 0 deletions docs/source/reference/text_classification_multi_label.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
.. _text_classification_multi_label:

###############################
Multi-label Text Classification
###############################

********
The task
********

Multi-label classification is the task of assigning a number of labels from a fixed set to each data point, which can be in any modality.
In this example, we will look at the task of classifying comment toxicity.

-----

********
The data
********
The data we will use in this example is from the kaggle toxic comment classification challenge by jigsaw: `www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge <https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge>`_.

------

*********
Inference
*********

We can load a pretrained :class:`~flash.text.classification.model.TextClassifier` and perform inference on any string sequence using :func:`~flash.text.classification.model.TextClassifier.predict`:

.. literalinclude:: ../../../flash_examples/predict/text_classification_multi_label.py
:language: python
:lines: 14-

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Now let's look at how we can finetune a model on the toxic comments data.
Once we download the data using :func:`~flash.core.data.download_data`, we can create our :meth:`~flash.text.classification.data.TextClassificationData` using :meth:`~flash.core.data.data_module.DataModule.from_csv`.
The backbone can be any BERT classification model from Huggingface.
We use ``"unitary/toxic-bert"`` as the backbone since it's already trained on the toxic comments data.
Now all we need to do is fine-tune our model!

.. literalinclude:: ../../../flash_examples/finetuning/text_classification_multi_label.py
:language: python
:lines: 14-

----

To run the example:

.. code-block:: bash
python flash_examples/finetuning/text_classification_multi_label.py
14 changes: 9 additions & 5 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,15 @@ def _split_train_val(
"`val_split` should be `None` when the dataset is built with an IterableDataset."
)

train_num_samples = len(train_dataset)
val_num_samples = int(train_num_samples * val_split)
val_indices = list(np.random.choice(range(train_num_samples), val_num_samples, replace=False))
train_indices = [i for i in range(train_num_samples) if i not in val_indices]
return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices)
val_num_samples = int(len(train_dataset) * val_split)
indices = list(range(len(train_dataset)))
np.random.shuffle(indices)
val_indices = indices[:val_num_samples]
train_indices = indices[val_num_samples:]
return (
SplitDataset(train_dataset, train_indices, use_duplicated_indices=True),
SplitDataset(train_dataset, val_indices, use_duplicated_indices=True),
)

@classmethod
def from_data_source(
Expand Down
6 changes: 5 additions & 1 deletion flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,10 @@ def step(self, batch: Any, batch_idx: int) -> Any:
x, y = batch
y_hat = self(x)
output = {"y_hat": y_hat}
y_hat = self.to_loss_format(output["y_hat"])
losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
logs = {}
y_hat = self.to_metrics_format(y_hat)
y_hat = self.to_metrics_format(output["y_hat"])
for name, metric in self.metrics.items():
if isinstance(metric, torchmetrics.metric.Metric):
metric(y_hat, y)
Expand All @@ -164,6 +165,9 @@ def step(self, batch: Any, batch_idx: int) -> Any:
output["y"] = y
return output

def to_loss_format(self, x: torch.Tensor) -> torch.Tensor:
return x

def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
return x

Expand Down
51 changes: 30 additions & 21 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,24 @@ def __init__(self, filetype: str, backbone: str, max_length: int = 128):

self.filetype = filetype

def _multilabel_target(self, targets, element):
targets = list(element.pop(target) for target in targets)
element["labels"] = targets
return element

def load_data(
self,
data: Tuple[str, Union[str, List[str]], Union[str, List[str]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
csv_file, input, target = data
file, input, target = data

data_files = {}

stage = self.running_stage.value
data_files[stage] = str(csv_file)
data_files[stage] = str(file)

# FLASH_TESTING is set in the CI to run faster.
# FLASH_TESTING is set in the CI to run faster.
if flash._IS_TESTING and not torch.cuda.is_available():
try:
Expand All @@ -108,26 +112,31 @@ def load_data(
else:
dataset_dict = load_dataset(self.filetype, data_files=data_files)

if self.training:
labels = list(sorted(list(set(dataset_dict[stage][target]))))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))

labels = self.get_state(LabelsState)

# convert labels to ids
# if not self.predicting:
if labels is not None:
labels = labels.labels
label_to_class_mapping = {v: k for k, v in enumerate(labels)}
dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target))
if not self.predicting:
if isinstance(target, List):
# multi-target
dataset_dict = dataset_dict.map(partial(self._multilabel_target, target))
dataset.num_classes = len(target)
self.set_state(LabelsState(target))
else:
if self.training:
labels = list(sorted(list(set(dataset_dict[stage][target]))))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))

labels = self.get_state(LabelsState)

# convert labels to ids
if labels is not None:
labels = labels.labels
label_to_class_mapping = {v: k for k, v in enumerate(labels)}
dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target))

# Hugging Face models expect target to be named ``labels``.
if target != "labels":
dataset_dict.rename_column_(target, "labels")

dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input), batched=True)

# Hugging Face models expect target to be named ``labels``.
if not self.predicting and target != "labels":
dataset_dict.rename_column_(target, "labels")

dataset_dict.set_format("torch", columns=columns)

return dataset_dict[stage]
Expand Down
64 changes: 24 additions & 40 deletions flash/text/classification/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch

from flash.core.classification import ClassificationTask
from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.utilities.imports import _TEXT_AVAILABLE

Expand All @@ -43,6 +43,7 @@ def __init__(
self,
num_classes: int,
backbone: str = "prajjwal1/bert-medium",
loss_fn: Optional[Callable] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[Callable, Mapping, Sequence, None] = None,
learning_rate: float = 1e-2,
Expand All @@ -62,12 +63,12 @@ def __init__(

super().__init__(
model=None,
loss_fn=None,
loss_fn=loss_fn,
optimizer=optimizer,
metrics=metrics,
learning_rate=learning_rate,
multi_label=multi_label,
serializer=serializer,
serializer=serializer or Labels(multi_label=multi_label),
)
self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)

Expand All @@ -78,49 +79,32 @@ def backbone(self):
# see huggingface's BertForSequenceClassification
return self.model.bert

def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None
):
return self.model(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
def forward(self, batch: Dict[str, torch.Tensor]):
return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))

def to_loss_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
x = x.logits
return super().to_loss_format(x)

def to_metrics_format(self, x) -> torch.Tensor:
if isinstance(x, SequenceClassifierOutput):
x = x.logits
return super().to_metrics_format(x)

def step(self, batch, batch_idx) -> dict:
output = {}
out = self.forward(**batch)
loss, logits = out[:2]
output["loss"] = loss
output["y_hat"] = logits
if isinstance(logits, SequenceClassifierOutput):
logits = logits.logits
probs = torch.softmax(logits, 1)
output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()}
return output
target = batch.pop("labels")
batch = (batch, target)
return super().step(batch, batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
return self(**batch)
return self(batch)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""
This function is used only for debugging usage with CI
"""
assert history[-1]["val_accuracy"] > 0.730
if self.hparams.multi_label:
assert history[-1]["val_f1"] > 0.45
else:
assert history[-1]["val_accuracy"] > 0.73
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from flash.core.classification import Labels
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier
from flash.image.classification.data import ImageClassificationPreprocess

# 1. Download the data
# This is a subset of the movie poster genre prediction data set from the paper
Expand Down
59 changes: 59 additions & 0 deletions flash_examples/finetuning/text_classification_multi_label.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics import F1

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")

# 2. Load the data
datamodule = TextClassificationData.from_csv(
"comment_text",
["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
train_file="data/jigsaw_toxic_comments/train.csv",
test_file="data/jigsaw_toxic_comments/test.csv",
predict_file="data/jigsaw_toxic_comments/predict.csv",
batch_size=16,
val_split=0.1,
backbone="unitary/toxic-bert",
)

# 3. Build the model
model = TextClassifier(
num_classes=datamodule.num_classes,
multi_label=True,
metrics=F1(num_classes=datamodule.num_classes),
backbone="unitary/toxic-bert",
)

# 4. Create the trainer
trainer = flash.Trainer(fast_dev_run=True)

# 5. Fine-tune the model
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 6. Generate predictions for a few comments!
predictions = model.predict([
"No, he is an arrogant, self serving, immature idiot. Get it right.",
"U SUCK HANNAH MONTANA",
"Would you care to vote? Thx.",
])
print(predictions)

# 7. Save it!
trainer.save_checkpoint("text_classification_multi_label_model.pt")
Loading

0 comments on commit d23319d

Please sign in to comment.