diff --git a/.gitignore b/.gitignore
index bc4bf01665..063c3d52c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -155,4 +155,5 @@ kinetics
 movie_posters
 CameraRGB
 CameraSeg
+jigsaw_toxic_comments
 flash_examples/serve/tabular_classification/data
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a16eaaa6b2..3f4a050f76 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `torch.jit` to tasks where possible and documented task JIT compatibility ([#389](https://github.com/PyTorchLightning/lightning-flash/pull/389))
 - Added option to provide a `Sampler` to the `DataModule` to use when creating a `DataLoader` ([#390](https://github.com/PyTorchLightning/lightning-flash/pull/390))
+- Added support for multi-label text classification and toxic comments example ([#401](https://github.com/PyTorchLightning/lightning-flash/pull/401))
 
 ### Changed
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f712291dd8..553811acab 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -22,10 +22,11 @@ Lightning Flash
 
    reference/task
    reference/image_classification
+   reference/image_classification_multi_label
    reference/image_embedder
-   reference/multi_label_classification
    reference/summarization
    reference/text_classification
+   reference/text_classification_multi_label
    reference/tabular_classification
    reference/translation
    reference/object_detection
diff --git a/docs/source/reference/multi_label_classification.rst b/docs/source/reference/image_classification_multi_label.rst
similarity index 98%
rename from docs/source/reference/multi_label_classification.rst
rename to docs/source/reference/image_classification_multi_label.rst
index 7b75bb7ada..e245fe7ff4 100644
--- a/docs/source/reference/multi_label_classification.rst
+++ b/docs/source/reference/image_classification_multi_label.rst
@@ -1,5 +1,5 @@
-.. _multi_label_classification:
+.. _image_classification_multi_label:
 
 ################################
 Multi-label Image Classification
 ################################
diff --git a/docs/source/reference/text_classification_multi_label.rst b/docs/source/reference/text_classification_multi_label.rst
new file mode 100644
index 0000000000..72bf2f271e
--- /dev/null
+++ b/docs/source/reference/text_classification_multi_label.rst
@@ -0,0 +1,57 @@
+.. _text_classification_multi_label:
+
+###############################
+Multi-label Text Classification
+###############################
+
+********
+The task
+********
+
+Multi-label classification is the task of assigning a number of labels from a fixed set to each data point; the data points can be in any modality.
+In this example, we will look at the task of classifying comment toxicity.
+
+-----
+
+********
+The data
+********
+The data we will use in this example is from the Kaggle Toxic Comment Classification Challenge by Jigsaw: `www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge <https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge>`_.
+
+------
+
+*********
+Inference
+*********
+
+We can load a pretrained :class:`~flash.text.classification.model.TextClassifier` and perform inference on any string sequence using :func:`~flash.text.classification.model.TextClassifier.predict`:
+
+.. literalinclude:: ../../../flash_examples/predict/text_classification_multi_label.py
+    :language: python
+    :lines: 14-
+
+For more advanced inference options, see :ref:`predictions`.
+
+-----
+
+**********
+Finetuning
+**********
+
+Now let's look at how we can finetune a model on the toxic comments data.
+Once we download the data using :func:`~flash.core.data.utils.download_data`, we can create our :class:`~flash.text.classification.data.TextClassificationData` using :meth:`~flash.core.data.data_module.DataModule.from_csv`.
+The backbone can be any BERT classification model from Hugging Face.
+We use ``"unitary/toxic-bert"`` as the backbone since it has already been trained on the toxic comments data.
+Now all we need to do is finetune our model!
+
+.. literalinclude:: ../../../flash_examples/finetuning/text_classification_multi_label.py
+    :language: python
+    :lines: 14-
+
+----
+
+To run the example:
+
+.. code-block:: bash
+
+    python flash_examples/finetuning/text_classification_multi_label.py
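The two literalinclude directives in the new docs page pull their code from the example scripts added further down in this patch. For readers skimming the diff, here is a condensed sketch of the flow the docs describe; every path, column name, and argument below is taken from those examples rather than invented:

import flash
from flash.core.data.utils import download_data
from flash.text import TextClassificationData, TextClassifier

# Grab the Jigsaw toxic comments subset hosted for the Flash examples.
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")

# Passing a list of target columns (instead of a single name) switches on multi-label mode.
datamodule = TextClassificationData.from_csv(
    "comment_text",
    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
    train_file="data/jigsaw_toxic_comments/train.csv",
    val_split=0.1,
    backbone="unitary/toxic-bert",
)

model = TextClassifier(
    num_classes=datamodule.num_classes,
    multi_label=True,
    backbone="unitary/toxic-bert",
)
flash.Trainer(fast_dev_run=True).finetune(model, datamodule=datamodule, strategy="freeze")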
diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py
index 4ed185e93b..7a21498608 100644
--- a/flash/core/data/data_module.py
+++ b/flash/core/data/data_module.py
@@ -357,11 +357,15 @@ def _split_train_val(
                 "`val_split` should be `None` when the dataset is built with an IterableDataset."
             )
 
-        train_num_samples = len(train_dataset)
-        val_num_samples = int(train_num_samples * val_split)
-        val_indices = list(np.random.choice(range(train_num_samples), val_num_samples, replace=False))
-        train_indices = [i for i in range(train_num_samples) if i not in val_indices]
-        return SplitDataset(train_dataset, train_indices), SplitDataset(train_dataset, val_indices)
+        val_num_samples = int(len(train_dataset) * val_split)
+        indices = list(range(len(train_dataset)))
+        np.random.shuffle(indices)
+        val_indices = indices[:val_num_samples]
+        train_indices = indices[val_num_samples:]
+        return (
+            SplitDataset(train_dataset, train_indices, use_duplicated_indices=True),
+            SplitDataset(train_dataset, val_indices, use_duplicated_indices=True),
+        )
 
     @classmethod
     def from_data_source(
diff --git a/flash/core/model.py b/flash/core/model.py
index 2d3bbe6166..8664195dfe 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -146,9 +146,10 @@ def step(self, batch: Any, batch_idx: int) -> Any:
         x, y = batch
         y_hat = self(x)
         output = {"y_hat": y_hat}
+        y_hat = self.to_loss_format(output["y_hat"])
         losses = {name: l_fn(y_hat, y) for name, l_fn in self.loss_fn.items()}
         logs = {}
-        y_hat = self.to_metrics_format(y_hat)
+        y_hat = self.to_metrics_format(output["y_hat"])
         for name, metric in self.metrics.items():
             if isinstance(metric, torchmetrics.metric.Metric):
                 metric(y_hat, y)
@@ -164,6 +165,9 @@ def step(self, batch: Any, batch_idx: int) -> Any:
         output["y"] = y
         return output
 
+    def to_loss_format(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
     def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor:
         return x
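The rewritten _split_train_val replaces the old `i not in val_indices` membership test, which scanned the validation list once per sample and made the split quadratic in the dataset size; shuffling the index list once and slicing it is linear and yields the same disjoint split. A standalone sketch of the idea (the function name is illustrative, not part of the Flash API):

import numpy as np

def split_indices(num_samples: int, val_split: float):
    # Shuffle once, then slice: O(n), with train and val guaranteed disjoint.
    val_num_samples = int(num_samples * val_split)
    indices = list(range(num_samples))
    np.random.shuffle(indices)  # shuffles the list in place
    return indices[val_num_samples:], indices[:val_num_samples]  # train, val

train_idx, val_idx = split_indices(100, 0.1)
assert len(val_idx) == 10
assert not set(train_idx) & set(val_idx)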
diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py
index 541a887733..543ee24b2b 100644
--- a/flash/text/classification/data.py
+++ b/flash/text/classification/data.py
@@ -83,20 +83,24 @@ def __init__(self, filetype: str, backbone: str, max_length: int = 128):
 
         self.filetype = filetype
 
+    def _multilabel_target(self, targets, element):
+        targets = list(element.pop(target) for target in targets)
+        element["labels"] = targets
+        return element
+
     def load_data(
         self,
         data: Tuple[str, Union[str, List[str]], Union[str, List[str]]],
         dataset: Optional[Any] = None,
         columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
     ) -> Union[Sequence[Mapping[str, Any]]]:
-        csv_file, input, target = data
+        file, input, target = data
 
         data_files = {}
 
         stage = self.running_stage.value
-        data_files[stage] = str(csv_file)
+        data_files[stage] = str(file)
 
-        # FLASH_TESTING is set in the CI to run faster.
         # FLASH_TESTING is set in the CI to run faster.
         if flash._IS_TESTING and not torch.cuda.is_available():
             try:
@@ -108,26 +112,31 @@ def load_data(
         else:
             dataset_dict = load_dataset(self.filetype, data_files=data_files)
 
-        if self.training:
-            labels = list(sorted(list(set(dataset_dict[stage][target]))))
-            dataset.num_classes = len(labels)
-            self.set_state(LabelsState(labels))
-
-        labels = self.get_state(LabelsState)
-
-        # convert labels to ids
-        # if not self.predicting:
-        if labels is not None:
-            labels = labels.labels
-            label_to_class_mapping = {v: k for k, v in enumerate(labels)}
-            dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target))
+        if not self.predicting:
+            if isinstance(target, List):
+                # multi-target
+                dataset_dict = dataset_dict.map(partial(self._multilabel_target, target))
+                dataset.num_classes = len(target)
+                self.set_state(LabelsState(target))
+            else:
+                if self.training:
+                    labels = list(sorted(list(set(dataset_dict[stage][target]))))
+                    dataset.num_classes = len(labels)
+                    self.set_state(LabelsState(labels))
+
+                labels = self.get_state(LabelsState)
+
+                # convert labels to ids
+                if labels is not None:
+                    labels = labels.labels
+                    label_to_class_mapping = {v: k for k, v in enumerate(labels)}
+                    dataset_dict = dataset_dict.map(partial(self._transform_label, label_to_class_mapping, target))
+
+            # Hugging Face models expect target to be named ``labels``.
+            if target != "labels":
+                dataset_dict.rename_column_(target, "labels")
 
         dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input), batched=True)
-
-        # Hugging Face models expect target to be named ``labels``.
-        if not self.predicting and target != "labels":
-            dataset_dict.rename_column_(target, "labels")
-
         dataset_dict.set_format("torch", columns=columns)
 
         return dataset_dict[stage]
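To make the new multi-label branch concrete: _multilabel_target pops every target column out of a dataset row and gathers the values under a single `labels` key, the column name Hugging Face models expect. A standalone rendition with an invented sample row:

# Standalone rendition of the _multilabel_target mapping (the sample row is made up).
def multilabel_target(targets, element):
    element["labels"] = [element.pop(target) for target in targets]
    return element

targets = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
row = {"comment_text": "U SUCK HANNAH MONTANA", "toxic": 1, "severe_toxic": 0,
       "obscene": 1, "threat": 0, "insult": 1, "identity_hate": 0}
print(multilabel_target(targets, row))
# {'comment_text': 'U SUCK HANNAH MONTANA', 'labels': [1, 0, 1, 0, 1, 0]}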
diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py
index 9ae993f8a2..ec6edd1c93 100644
--- a/flash/text/classification/model.py
+++ b/flash/text/classification/model.py
@@ -17,7 +17,7 @@
 
 import torch
 
-from flash.core.classification import ClassificationTask
+from flash.core.classification import ClassificationTask, Labels
 from flash.core.data.process import Serializer
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 
@@ -43,6 +43,7 @@ def __init__(
         self,
         num_classes: int,
         backbone: str = "prajjwal1/bert-medium",
+        loss_fn: Optional[Callable] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[Callable, Mapping, Sequence, None] = None,
         learning_rate: float = 1e-2,
@@ -62,12 +63,12 @@ def __init__(
 
         super().__init__(
             model=None,
-            loss_fn=None,
+            loss_fn=loss_fn,
            optimizer=optimizer,
             metrics=metrics,
             learning_rate=learning_rate,
             multi_label=multi_label,
-            serializer=serializer,
+            serializer=serializer or Labels(multi_label=multi_label),
         )
 
         self.model = BertForSequenceClassification.from_pretrained(backbone, num_labels=num_classes)
@@ -78,49 +79,32 @@ def backbone(self):
         # see huggingface's BertForSequenceClassification
         return self.model.bert
 
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None
-    ):
-        return self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            labels=labels,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
+    def forward(self, batch: Dict[str, torch.Tensor]):
+        return self.model(input_ids=batch.get("input_ids", None), attention_mask=batch.get("attention_mask", None))
+
+    def to_loss_format(self, x) -> torch.Tensor:
+        if isinstance(x, SequenceClassifierOutput):
+            x = x.logits
+        return super().to_loss_format(x)
+
+    def to_metrics_format(self, x) -> torch.Tensor:
+        if isinstance(x, SequenceClassifierOutput):
+            x = x.logits
+        return super().to_metrics_format(x)
 
     def step(self, batch, batch_idx) -> dict:
-        output = {}
-        out = self.forward(**batch)
-        loss, logits = out[:2]
-        output["loss"] = loss
-        output["y_hat"] = logits
-        if isinstance(logits, SequenceClassifierOutput):
-            logits = logits.logits
-        probs = torch.softmax(logits, 1)
-        output["logs"] = {name: metric(probs, batch["labels"]) for name, metric in self.metrics.items()}
-        return output
+        target = batch.pop("labels")
+        batch = (batch, target)
+        return super().step(batch, batch_idx)
 
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
-        return self(**batch)
+        return self(batch)
 
     def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
         """
         This function is used only for debugging usage with CI
         """
-        assert history[-1]["val_accuracy"] > 0.730
+        if self.hparams.multi_label:
+            assert history[-1]["val_f1"] > 0.45
+        else:
+            assert history[-1]["val_accuracy"] > 0.73
diff --git a/flash_examples/finetuning/image_classification_multi_label.py b/flash_examples/finetuning/image_classification_multi_label.py
index 06d6ad2f35..79c24b8cc1 100644
--- a/flash_examples/finetuning/image_classification_multi_label.py
+++ b/flash_examples/finetuning/image_classification_multi_label.py
@@ -21,7 +21,6 @@
 from flash.core.classification import Labels
 from flash.core.data.utils import download_data
 from flash.image import ImageClassificationData, ImageClassifier
-from flash.image.classification.data import ImageClassificationPreprocess
 
 # 1. Download the data
 # This is a subset of the movie poster genre prediction data set from the paper
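With this change, forward returns the raw SequenceClassifierOutput container from transformers, and the two to_*_format hooks unwrap its logits before the shared step in flash/core/model.py computes losses and metrics. A minimal sketch of the hook pattern; the Task base class below is a simplified stand-in, not the real Flash class:

from dataclasses import dataclass

import torch

@dataclass
class FakeSequenceClassifierOutput:
    # Stand-in for transformers.modeling_outputs.SequenceClassifierOutput.
    logits: torch.Tensor

class Task:
    def to_loss_format(self, x: torch.Tensor) -> torch.Tensor:
        return x  # identity by default; subclasses override

class TextTask(Task):
    def to_loss_format(self, x) -> torch.Tensor:
        if isinstance(x, FakeSequenceClassifierOutput):
            x = x.logits  # unwrap the Hugging Face output container
        return super().to_loss_format(x)

out = FakeSequenceClassifierOutput(logits=torch.zeros(1, 6))
assert TextTask().to_loss_format(out).shape == (1, 6)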
diff --git a/flash_examples/finetuning/text_classification_multi_label.py b/flash_examples/finetuning/text_classification_multi_label.py
new file mode 100644
index 0000000000..12dfb786ee
--- /dev/null
+++ b/flash_examples/finetuning/text_classification_multi_label.py
@@ -0,0 +1,59 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics import F1
+
+import flash
+from flash.core.data.utils import download_data
+from flash.text import TextClassificationData, TextClassifier
+
+# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
+# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
+download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")
+
+# 2. Load the data
+datamodule = TextClassificationData.from_csv(
+    "comment_text",
+    ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
+    train_file="data/jigsaw_toxic_comments/train.csv",
+    test_file="data/jigsaw_toxic_comments/test.csv",
+    predict_file="data/jigsaw_toxic_comments/predict.csv",
+    batch_size=16,
+    val_split=0.1,
+    backbone="unitary/toxic-bert",
+)
+
+# 3. Build the model
+model = TextClassifier(
+    num_classes=datamodule.num_classes,
+    multi_label=True,
+    metrics=F1(num_classes=datamodule.num_classes),
+    backbone="unitary/toxic-bert",
+)
+
+# 4. Create the trainer
+trainer = flash.Trainer(fast_dev_run=True)
+
+# 5. Fine-tune the model
+trainer.finetune(model, datamodule=datamodule, strategy="freeze")
+
+# 6. Generate predictions for a few comments!
+predictions = model.predict([
+    "No, he is an arrogant, self serving, immature idiot. Get it right.",
+    "U SUCK HANNAH MONTANA",
+    "Would you care to vote? Thx.",
+])
+print(predictions)
+
+# 7. Save it!
+trainer.save_checkpoint("text_classification_multi_label_model.pt")
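Note that the example configures a metric (F1) but no loss_fn. With multi_label=True, each of the six target columns is an independent binary decision, so the classification task presumably falls back to a sigmoid-based loss (binary cross-entropy with logits) rather than softmax cross-entropy. Conceptually:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 1.0, -0.5]])  # one comment, six label columns
target = torch.tensor([[1.0, 0.0, 1.0, 0.0, 1.0, 0.0]])     # independent binary targets

# Multi-label: one sigmoid per class, so the class probabilities need not sum to one.
loss = F.binary_cross_entropy_with_logits(logits, target)
probs = torch.sigmoid(logits)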
diff --git a/flash_examples/predict/text_classification_multi_label.py b/flash_examples/predict/text_classification_multi_label.py
new file mode 100644
index 0000000000..de42d31ffe
--- /dev/null
+++ b/flash_examples/predict/text_classification_multi_label.py
@@ -0,0 +1,42 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning import Trainer
+
+from flash.core.data.utils import download_data
+from flash.text import TextClassificationData, TextClassifier
+
+# 1. Download the data from the Kaggle Toxic Comment Classification Challenge:
+# https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
+download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "data/")
+
+# 2. Load the model from a checkpoint
+model = TextClassifier.load_from_checkpoint(
+    "https://flash-weights.s3.amazonaws.com/text_classification_multi_label_model.pt"
+)
+
+# 2a. Classify a few comments!
+predictions = model.predict([
+    "No, he is an arrogant, self serving, immature idiot. Get it right.",
+    "U SUCK HANNAH MONTANA",
+    "Would you care to vote? Thx.",
+])
+print(predictions)
+
+# 2b. Or generate predictions from a whole file!
+datamodule = TextClassificationData.from_csv(
+    "comment_text",
+    predict_file="data/jigsaw_toxic_comments/predict.csv",
+)
+predictions = Trainer().predict(model, datamodule=datamodule)
+print(predictions)
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index c811fdfa34..b628086657 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -51,7 +51,7 @@ def test_init_train(tmpdir):
 
 @pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.")
 def test_jit(tmpdir):
-    sample_input = torch.randint(1000, size=(1, 100))
+    sample_input = {"input_ids": torch.randint(1000, size=(1, 100))}
 
     path = os.path.join(tmpdir, "test.pt")
 
     model = TextClassifier(2, TEST_BACKBONE)
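Finally, what do the printed predictions look like? The model diff above installs Labels(multi_label=True) as the default serializer, so the expectation is one list of label names per input: the classes whose per-class probability clears a threshold. An illustrative post-processing step under that assumption (to_labels and the 0.5 threshold are hypothetical helpers, not the Flash API):

import torch

CLASSES = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def to_labels(logits: torch.Tensor, threshold: float = 0.5):
    # Hypothetical helper mirroring what a Labels-style serializer would produce.
    probs = torch.sigmoid(logits)
    return [[CLASSES[i] for i, p in enumerate(row) if p > threshold] for row in probs]

print(to_labels(torch.tensor([[3.0, -2.0, 1.5, -4.0, 0.8, -1.0]])))
# [['toxic', 'obscene', 'insult']]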