From 0446361a923267f74d91e9b1983409bcc98fa203 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Tue, 2 Feb 2021 14:52:40 +0000 Subject: [PATCH] Added Seq2Seq tasks (#37) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added Seq2Seq tasks * Use rank 0 for model specific params * Add licences * Fix summarization scripts * Fix comments, update from files API * Add tests * Add docs * Fix doc header * Apply suggestions from code review Co-authored-by: Carlos Mocholí * Add typing * Add imports, fix docs * Add rouge score for metric * fix imports * fix imports and style * Install sentencepiece for slow tokenizer conversion * yapf * Fixed underlines * Fixed doc references * Added min versions address formatting * Update requirement * Fix formatting issues * add seq to seq finetuning callback * docs: link blog * resolve tests * update * Delete lock file * remove download_model * Revert some changes, update requirements.txt * Move to mbart for now, even if it's a large model file * Clean up finetuning module, fix tests plus add todo * Cleanup * Update flash/text/seq2seq/core/model.py Co-authored-by: Jirka Borovec * Remove lock file, add typing * Change to test code * Swap to module available * Revert testcode due to test error Co-authored-by: Carlos Mocholí Co-authored-by: Jirka Borovec Co-authored-by: tchaton Co-authored-by: Jirka Borovec --- docs/source/_templates/theme_variables.jinja | 2 +- docs/source/index.rst | 2 + docs/source/reference/summarization.rst | 185 ++++++++ docs/source/reference/text_classification.rst | 12 +- docs/source/reference/translation.rst | 167 +++++++ flash/core/finetuning.py | 26 +- flash/core/model.py | 10 +- flash/core/trainer.py | 8 +- flash/text/__init__.py | 8 + flash/text/seq2seq/__init__.py | 3 + flash/text/seq2seq/core/__init__.py | 3 + flash/text/seq2seq/core/data.py | 275 +++++++++++ flash/text/seq2seq/core/finetuning.py | 35 ++ flash/text/seq2seq/core/model.py | 129 +++++ flash/text/seq2seq/summarization/__init__.py | 2 + flash/text/seq2seq/summarization/data.py | 132 ++++++ flash/text/seq2seq/summarization/metric.py | 109 +++++ flash/text/seq2seq/summarization/model.py | 59 +++ flash/text/seq2seq/summarization/utils.py | 30 ++ flash/text/seq2seq/translation/__init__.py | 2 + flash/text/seq2seq/translation/data.py | 134 ++++++ flash/text/seq2seq/translation/metric.py | 122 +++++ flash/text/seq2seq/translation/model.py | 73 +++ flash/vision/classification/model.py | 4 +- flash_examples/finetuning/summarization.py | 31 ++ flash_examples/finetuning/translation.py | 31 ++ flash_examples/predict/summarize.py | 44 ++ flash_examples/predict/translate.py | 27 ++ flash_notebooks/generic_task.ipynb | 445 ++---------------- flash_notebooks/image_classification.ipynb | 107 ++++- flash_notebooks/text_classification.ipynb | 114 ++++- requirements.txt | 5 +- tests/core/test_finetuning.py | 24 +- tests/core/test_model.py | 17 +- tests/core/test_trainer.py | 3 +- tests/text/summarization/__init__.py | 0 tests/text/summarization/test_data.py | 79 ++++ tests/text/summarization/test_model.py | 35 ++ tests/text/translation/__init__.py | 0 tests/text/translation/test_data.py | 79 ++++ tests/text/translation/test_model.py | 35 ++ tests/vision/classification/test_model.py | 3 +- 42 files changed, 2130 insertions(+), 481 deletions(-) create mode 100644 docs/source/reference/summarization.rst create mode 100644 docs/source/reference/translation.rst create mode 100644 flash/text/seq2seq/__init__.py create mode 100644 
flash/text/seq2seq/core/__init__.py create mode 100644 flash/text/seq2seq/core/data.py create mode 100644 flash/text/seq2seq/core/finetuning.py create mode 100644 flash/text/seq2seq/core/model.py create mode 100644 flash/text/seq2seq/summarization/__init__.py create mode 100644 flash/text/seq2seq/summarization/data.py create mode 100644 flash/text/seq2seq/summarization/metric.py create mode 100644 flash/text/seq2seq/summarization/model.py create mode 100644 flash/text/seq2seq/summarization/utils.py create mode 100644 flash/text/seq2seq/translation/__init__.py create mode 100644 flash/text/seq2seq/translation/data.py create mode 100644 flash/text/seq2seq/translation/metric.py create mode 100644 flash/text/seq2seq/translation/model.py create mode 100644 flash_examples/finetuning/summarization.py create mode 100644 flash_examples/finetuning/translation.py create mode 100644 flash_examples/predict/summarize.py create mode 100644 flash_examples/predict/translate.py create mode 100644 tests/text/summarization/__init__.py create mode 100644 tests/text/summarization/test_data.py create mode 100644 tests/text/summarization/test_model.py create mode 100644 tests/text/translation/__init__.py create mode 100644 tests/text/translation/test_data.py create mode 100644 tests/text/translation/test_model.py diff --git a/docs/source/_templates/theme_variables.jinja b/docs/source/_templates/theme_variables.jinja index 9143c852d7..cd4c1fe7e4 100644 --- a/docs/source/_templates/theme_variables.jinja +++ b/docs/source/_templates/theme_variables.jinja @@ -11,7 +11,7 @@ 'home': 'https://pytorchlightning.github.io/lightning-flash/', 'get_started': 'https://pytorchlightning.github.io/lightning-flash/quickstart.html', 'features': 'https://pytorchlightning.github.io/lightning-flash/', - 'blog': 'https://pytorchlightning.github.io/lightning-flash/', + 'blog': 'https://www.pytorchlightning.ai/blog', 'resources': 'https://pytorchlightning.github.io/lightning-flash/', 'support': 'https://pytorchlightning.github.io/lightning-flash/', } diff --git a/docs/source/index.rst b/docs/source/index.rst index cf06400c16..d9eaa37540 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -22,8 +22,10 @@ Lightning Flash reference/task reference/image_classification reference/image_embedder + reference/summarization reference/text_classification reference/tabular_classification + reference/translation .. toctree:: :maxdepth: 1 diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst new file mode 100644 index 0000000000..8931f6d850 --- /dev/null +++ b/docs/source/reference/summarization.rst @@ -0,0 +1,185 @@ +.. _summarization: + +############# +Summarization +############# + +******** +The task +******** + +Summarization is the task of summarizing text from a larger document/article into a short sentence/description. For example, taking a web article and describing the topic in a short sentence. +This task is a subset of Sequence to Sequence tasks, which requires the model to generate a variable length sequence given an input sequence. In our case the article would be our input sequence, and the short description/sentence would be the output sequence from the model. + +----- + +********* +Inference +********* + +The :class:`~flash.text.SummarizationTask` is already pre-trained on [XSUM](https://arxiv.org/abs/1808.08745), a dataset of online British Broadcasting Corporation articles. 
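+
+If you would rather start from the raw pretrained backbone than from the Flash checkpoint used below, the task can also be constructed directly (a minimal sketch; this downloads the default ``t5-small`` backbone on first use):
+
+.. code-block:: python
+
+    from flash.text import SummarizationTask
+
+    # build a fresh task around the default t5-small backbone
+    model = SummarizationTask(backbone="t5-small")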
+ +Use the :class:`~flash.text.SummarizationTask` pretrained model for inference on any string sequence using :func:`~flash.text.SummarizationTask.predict`: + +.. code-block:: python + + # import our libraries + from flash.text import SummarizationTask + + + # 2. Load the model from a checkpoint + model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") + + # 2. Perform inference from a sequence + predictions = model.predict([ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ + ]) + print(predictions) + +Or on a given dataset: + +.. code-block:: python + + # import our libraries + from pytorch_lightning import Trainer + from flash import download_data + from flash.text import SummarizationData, SummarizationTask + + # 2. Load the model from a checkpoint + model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") + + # 3. Create dataset from file + datamodule = SummarizationData.from_file( + predict_file="data/xsum/predict.csv", + input="input", + ) + + # 4. generate summaries + predictions = Trainer().predict(model, datamodule=datamodule) + print(predictions) + +For more advanced inference options, see :ref:`predictions`. + +----- + +********** +Finetuning +********** + +Say you want to finetune to your own summarization data. We use the XSUM dataset as an example which contains a ``train.csv`` and ``valid.csv``, structured like so: + +.. code-block:: + + input,target + "The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees." 
+ "Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005. + ... + +In the above the input column represents the long articles/documents, and the target is the short description used as the target. + +All we need is three lines of code to train our model! + +.. code-block:: python + + # import our libraries + import flash + from flash import download_data + from flash.text import SummarizationData, SummarizationTask + + # 1. Download data + download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') + + # Organize the data + datamodule = SummarizationData.from_files( + train_file="data/xsum/train.csv", + valid_file="data/xsum/valid.csv", + test_file="data/xsum/test.csv", + input="input", + target="target" + ) + + # 2. Build the task + model = SummarizationTask() + + # 4. Create trainer + trainer = flash.Trainer(max_epochs=1, gpus=1) + + # 5. Finetune the task + trainer.finetune(model, datamodule=datamodule) + + # 6. Save trainer task + trainer.save_checkpoint("summarization_model_xsum.pt") + +---- + +To run the example: + +.. code-block:: bash + + python flash_examples/finetuning/summarization.py + + +------ + +********************* +Changing the backbone +********************* +By default, we use the `t5 `_ model for summarization. You can change the model run by passing in the backbone parameter. + +.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model. + +.. code-block:: python + + datamodule = SummarizationData.from_files( + train_file="data/wmt_en_ro/train.csv", + valid_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", + backbone="google/mt5-small", + ) + + model = SummarizationTask(backbone="google/mt5-small") + +------ + +************* +API reference +************* + +.. _summarization_task: + +SummarizationTask +----------------- + +.. autoclass:: flash.text.seq2seq.summarization.model.SummarizationTask + :members: + :exclude-members: forward + +.. _summarization_data: + +SummarizationData +----------------- + +.. autoclass:: flash.text.seq2seq.summarization.data.SummarizationData + +.. automethod:: flash.text.seq2seq.summarization.data.SummarizationData.from_files diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst index b87a21cf8e..31b5f9ffcb 100644 --- a/docs/source/reference/text_classification.rst +++ b/docs/source/reference/text_classification.rst @@ -16,9 +16,9 @@ Text classification is the task of assigning a piece of text (word, sentence or Inference ********* -The :class:`~flash.text.TextClassificatier` is already pre-trained on [IMDB](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews), a dataset of highly polarized movie reviews, trained for binary classification- to predict if a given review has a positive or negative sentiment. +The :class:`~flash.text.TextClassifier` is already pre-trained on [IMDB](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews), a dataset of highly polarized movie reviews, trained for binary classification- to predict if a given review has a positive or negative sentiment. 
-Use the :class:`~flash.text.TextClassificatier` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`: +Use the :class:`~flash.text.TextClassifier` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`: .. code-block:: python @@ -83,10 +83,10 @@ All we need is three lines of code to train our model! .. code-block:: python - # import our libraries - import flash - from flash import download_data - from flash.text import TextClassificationData, TextClassifier + # import our libraries + import flash + from flash import download_data + from flash.text import TextClassificationData, TextClassifier # 1. Download data download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", 'data/') diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst new file mode 100644 index 0000000000..96f659fe0a --- /dev/null +++ b/docs/source/reference/translation.rst @@ -0,0 +1,167 @@ +.. _translation: + +########### +Translation +########### + +******** +The task +******** + +Translation is the task of translating text from a source language to another, such as English to Romanian. +This task is a subset of Sequence to Sequence tasks, which requires the model to generate a variable length sequence given an input sequence. In our case the English text would be our input sequence, and the Romanian sentence would be the output sequence from the model. + +----- + +********* +Inference +********* + +The :class:`~flash.text.TranslationTask` is already pre-trained on [WMT16 English/Romanian](https://www.statmt.org/wmt16/translation-task.html), a dataset of English to Romanian samples, based on the Europarl corpora. + +Use the :class:`~flash.text.TranslationTask` pretrained model for inference on any string sequence using :func:`~flash.text.TranslationTask.predict`: + +.. code-block:: python + + # import our libraries + from flash.text import TranslationTask + + + # 2. Load the model from a checkpoint + model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") + + # 2. Perform inference from list of sequences + predictions = model.predict([ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + ]) + print(predictions) + +Or on a given dataset: + +.. code-block:: python + + # import our libraries + from pytorch_lightning import Trainer + from flash import download_data + from flash.text import TranslationData, TranslationTask + + # 2. Load the model from a checkpoint + model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") + + # 3. Create dataset from file + datamodule = TranslationData.from_file( + predict_file="data/wmt_en_ro/predict.csv", + input="input", + ) + + # 4. generate translations + predictions = Trainer().predict(model, datamodule=datamodule) + print(predictions) + +For more advanced inference options, see :ref:`predictions`. + +----- + +********** +Finetuning +********** + +Say you want to finetune to your own translation data. We use the English/Romanian WMT16 dataset as an example which contains a ``train.csv`` and ``valid.csv``, structured like so: + +.. 
code-block:: + + input,target + "Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal" + "Closure of sitting","Ridicarea şedinţei" + ... + +In the above the input/target columns represent the English and Romanian translation respectively. + +All we need is three lines of code to train our model! + +.. code-block:: python + + # import our libraries + import flash + from flash import download_data + from flash.text import TranslationData, TranslationTask + + # 1. Download data + download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') + + # Organize the data + datamodule = TranslationData.from_files( + train_file="data/wmt_en_ro/train.csv", + valid_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", + ) + + # 2. Build the task + model = TranslationTask() + + # 4. Create trainer + trainer = flash.Trainer(max_epochs=5, gpus=1, precision=16) + + # 5. Finetune the task + trainer.finetune(model, datamodule=datamodule) + + # 6. Save trainer task + trainer.save_checkpoint("translation_model_en_ro.pt") + +---- + +To run the example: + +.. code-block:: bash + + python flash_examples/finetuning/translation.py + + +------ + +********************* +Changing the backbone +********************* +By default, we use `mBART `_ model for translation. You can change the model run by passing in the backbone parameter. + +.. note:: When changing the backbone, make sure you pass in the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model. + +.. code-block:: python + + datamodule = TranslationData.from_files( + train_file="data/wmt_en_ro/train.csv", + valid_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", + backbone="t5-small", + ) + + model = TranslationTask(backbone="t5-small") + +------ + +************* +API reference +************* + +.. _translation_task: + +TranslationTask +--------------- + +.. autoclass:: flash.text.seq2seq.translation.model.TranslationTask + :members: + :exclude-members: forward + +.. _translation_data: + +TranslationData +--------------- + +.. autoclass:: flash.text.seq2seq.translation.data.TranslationData + +.. 
automethod:: flash.text.seq2seq.translation.data.TranslationData.from_files diff --git a/flash/core/finetuning.py b/flash/core/finetuning.py index 5c365cb18d..8e402dff56 100644 --- a/flash/core/finetuning.py +++ b/flash/core/finetuning.py @@ -21,7 +21,18 @@ class NoFreeze(BaseFinetuning): - pass + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + pass + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + pass class FlashBaseFinetuning(BaseFinetuning): @@ -53,9 +64,20 @@ def freeze_using_attr_names(self, pl_module, attr_names: List[str], train_bn: bo MisconfigurationException(f"Your model must have a {attr} attribute") self.freeze(module=attr, train_bn=train_bn) + def finetunning_function(self, pl_module: pl.LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + pass + class Freeze(FlashBaseFinetuning): - pass + + def finetunning_function( + self, + pl_module: pl.LightningModule, + epoch: int, + optimizer: Optimizer, + opt_idx: int, + ) -> None: + pass class FreezeUnfreeze(FlashBaseFinetuning): diff --git a/flash/core/model.py b/flash/core/model.py index 7c4e2a2f1b..ea11d1bd61 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union +from typing import Any, Callable, Dict, IO, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl import torch from torch import nn -from flash.core.data import DataModule, DataPipeline +from flash.core.data import DataModule, DataPipeline, download_data from flash.core.utils import get_callable_dict @@ -49,7 +49,7 @@ class Task(pl.LightningModule): loss_fn: Loss function for training optimizer: Optimizer to use for training, defaults to `torch.optim.SGD`. metrics: Metrics to compute for training and evaluation. - learning_rate: Learning rate to use for training, defaults to `1e-3` + learning_rate: Learning rate to use for training, defaults to `5e-5` """ def __init__( @@ -58,7 +58,7 @@ def __init__( loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 1e-3, + learning_rate: float = 5e-5, ): super().__init__() if model is not None: @@ -143,7 +143,7 @@ def predict( """ data_pipeline = data_pipeline or self.data_pipeline batch = x if skip_collate_fn else data_pipeline.collate_fn(x) - batch_x, batch_y = batch if len(batch) == 2 else (batch, None) + batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None) predictions = self.forward(batch_x) output = data_pipeline.uncollate_fn(predictions) # TODO: pass batch and x return output diff --git a/flash/core/trainer.py b/flash/core/trainer.py index e570d4ae2d..d58a54b9df 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -100,7 +100,7 @@ def _resolve_callbacks(self, model, strategy): ) if isinstance(strategy, BaseFinetuning): - callback = strategy + callback = [strategy] else: # todo: change to ``configure_callbacks`` when merged to Lightning. model_callback = model.configure_finetune_callback() @@ -115,11 +115,11 @@ def _resolve_callbacks(self, model, strategy): f"The provided {strategy} will be overriden. 
" "HINT: Provide a `BaseFinetuning` callback as strategy to make it prioritized. ", UserWarning ) - callback = [model_callback] + callback = model_callback else: callback = instantiate_default_finetuning_callbacks(strategy) - self.callbacks = self._merge_callbacks(self.callbacks, [callback]) + self.callbacks = self._merge_callbacks(self.callbacks, callback) @staticmethod def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: @@ -127,7 +127,7 @@ def _merge_callbacks(old_callbacks: List, new_callbacks: List) -> List: This function keeps only 1 instance of each callback type, extending new_callbacks with old_callbacks """ - if len(new_callbacks): + if len(new_callbacks) == 0: return old_callbacks new_callbacks_types = set(type(c) for c in new_callbacks) old_callbacks_types = set(type(c) for c in old_callbacks) diff --git a/flash/text/__init__.py b/flash/text/__init__.py index 67e949dcb6..9f15fc6208 100644 --- a/flash/text/__init__.py +++ b/flash/text/__init__.py @@ -1 +1,9 @@ from flash.text.classification import TextClassificationData, TextClassifier +from flash.text.seq2seq import ( + Seq2SeqData, + Seq2SeqTask, + SummarizationData, + SummarizationTask, + TranslationData, + TranslationTask, +) diff --git a/flash/text/seq2seq/__init__.py b/flash/text/seq2seq/__init__.py new file mode 100644 index 0000000000..5ae1a678bc --- /dev/null +++ b/flash/text/seq2seq/__init__.py @@ -0,0 +1,3 @@ +from flash.text.seq2seq.core import Seq2SeqData, Seq2SeqFreezeEmbeddings, Seq2SeqTask +from flash.text.seq2seq.summarization import SummarizationData, SummarizationTask +from flash.text.seq2seq.translation import TranslationData, TranslationTask diff --git a/flash/text/seq2seq/core/__init__.py b/flash/text/seq2seq/core/__init__.py new file mode 100644 index 0000000000..1f43f0d348 --- /dev/null +++ b/flash/text/seq2seq/core/__init__.py @@ -0,0 +1,3 @@ +from flash.text.seq2seq.core.data import Seq2SeqData +from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings +from flash.text.seq2seq.core.model import Seq2SeqTask diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py new file mode 100644 index 0000000000..a86e2af5c4 --- /dev/null +++ b/flash/text/seq2seq/core/data.py @@ -0,0 +1,275 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from functools import partial +from typing import Any, Callable, Optional, Union + +from datasets import load_dataset +from datasets.utils.download_manager import GenerateMode +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from torch import Tensor +from transformers import AutoTokenizer, default_data_collator + +from flash.core.data import DataModule, TaskDataPipeline + + +def prepare_dataset( + test_file: str, + filetype: str, + pipeline: TaskDataPipeline, + train_file: Optional[str] = None, + valid_file: Optional[str] = None, + predict: bool = False +): + data_files = {} + + if train_file is not None: + data_files["train"] = train_file + if valid_file is not None: + data_files["validation"] = valid_file + if test_file is not None: + data_files["test"] = test_file + + # load the dataset + dataset_dict = load_dataset( + filetype, + data_files=data_files, + ) + + # tokenize the dataset + dataset_dict = dataset_dict.map( + pipeline._tokenize_fn, + batched=True, + ) + columns = ["input_ids", "attention_mask"] if predict else ["input_ids", "attention_mask", "labels"] + dataset_dict.set_format(columns=columns) + + train_ds = None + valid_ds = None + test_ds = None + + if "train" in dataset_dict: + train_ds = dataset_dict["train"] + + if "validation" in dataset_dict: + valid_ds = dataset_dict["validation"] + + if "test" in dataset_dict: + test_ds = dataset_dict["test"] + + return train_ds, valid_ds, test_ds + + +class Seq2SeqDataPipeline(TaskDataPipeline): + + def __init__( + self, + tokenizer, + input: str, + target: Optional[str] = None, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = 'longest' + ): + self.tokenizer = tokenizer + self._input = input + self._target = target + self._max_target_length = max_target_length + self._max_source_length = max_source_length + self._padding = padding + self._tokenize_fn = partial( + self._tokenize_fn, + tokenizer=self.tokenizer, + input=self._input, + target=self._target, + max_source_length=self._max_source_length, + max_target_length=self._max_target_length, + padding=self._padding + ) + + def before_collate(self, samples: Any) -> Any: + """Override to apply transformations to samples""" + if isinstance(samples, (list, tuple)) and len(samples) > 0 and all(isinstance(s, str) for s in samples): + return [self._tokenize_fn({self._input: s, self._target: None}) for s in samples] + return samples + + @staticmethod + def _tokenize_fn( + ex, + tokenizer, + input: str, + target: Optional[str], + max_source_length: int, + max_target_length: int, + padding: Union[str, bool], + ) -> Callable: + output = tokenizer.prepare_seq2seq_batch( + src_texts=ex[input], + tgt_texts=ex[target] if target else None, + max_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ) + return output + + def collate(self, samples: Any) -> Tensor: + """Override to convert a set of samples to a batch""" + return default_data_collator(samples) + + def after_collate(self, batch: Any) -> Any: + return batch + + def uncollate(self, generated_tokens: Any) -> Any: + pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + pred_str = [str.strip(s) for s in pred_str] + return pred_str + + +class Seq2SeqData(DataModule): + """Data module for Seq2Seq tasks.""" + + @staticmethod + def default_pipeline(): + return Seq2SeqDataPipeline( + AutoTokenizer.from_pretrained("sshleifer/tiny-mbart", use_fast=True), + input="input", + ) + + @classmethod + def from_files( + 
        cls,
+        train_file: str,
+        input: str = 'input',
+        target: Optional[str] = None,
+        filetype: str = "csv",
+        backbone: str = "sshleifer/tiny-mbart",
+        valid_file: Optional[str] = None,
+        test_file: Optional[str] = None,
+        max_source_length: int = 128,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'max_length',
+        batch_size: int = 32,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a Seq2SeqData object from files.
+
+        Args:
+            train_file: Path to training data.
+            input: The field storing the source text.
+            target: The field storing the target text.
+            filetype: .csv or .json
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            valid_file: Path to validation data.
+            test_file: Path to test data.
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the maximum length.
+            batch_size: the batch size to use for parallel loading. Defaults to 32.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            Seq2SeqData: The constructed data module.
+
+        Examples::
+
+            datamodule = Seq2SeqData.from_files(train_file="train_data.csv",
+                                                input="input",
+                                                target="target",
+                                                backbone="sshleifer/tiny-mbart")
+
+        """
+        tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
+
+        pipeline = Seq2SeqDataPipeline(
+            tokenizer=tokenizer,
+            input=input,
+            target=target,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding
+        )
+
+        train_ds, valid_ds, test_ds = prepare_dataset(
+            train_file=train_file, valid_file=valid_file, test_file=test_file, filetype=filetype, pipeline=pipeline
+        )
+
+        datamodule = cls(
+            train_ds=train_ds,
+            valid_ds=valid_ds,
+            test_ds=test_ds,
+            batch_size=batch_size,
+            num_workers=num_workers,
+        )
+
+        datamodule.data_pipeline = pipeline
+        return datamodule
+
+    @classmethod
+    def from_file(
+        cls,
+        predict_file: str,
+        input: str = 'input',
+        target: Optional[str] = None,
+        backbone: str = "sshleifer/tiny-mbart",
+        filetype: str = "csv",
+        max_source_length: int = 128,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'max_length',
+        batch_size: int = 32,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a Seq2SeqData object from a prediction input file.
+
+        Args:
+            predict_file: Path to prediction input file.
+            input: The field storing the source text.
+            target: The field storing the target text.
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            filetype: .csv or .json
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the maximum length.
+            batch_size: the batch size to use for parallel loading. Defaults to 32.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            Seq2SeqData: The constructed data module.
+ + """ + tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + + pipeline = Seq2SeqDataPipeline( + tokenizer=tokenizer, + input=input, + target=target, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding + ) + + train_ds, valid_ds, test_ds = prepare_dataset( + test_file=predict_file, filetype=filetype, pipeline=pipeline, predict=True + ) + + datamodule = cls( + train_ds=train_ds, + valid_ds=valid_ds, + test_ds=test_ds, + batch_size=batch_size, + num_workers=num_workers, + ) + + datamodule.data_pipeline = pipeline + return datamodule diff --git a/flash/text/seq2seq/core/finetuning.py b/flash/text/seq2seq/core/finetuning.py new file mode 100644 index 0000000000..dc4c0f7c56 --- /dev/null +++ b/flash/text/seq2seq/core/finetuning.py @@ -0,0 +1,35 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytorch_lightning as pl + +from flash.core.finetuning import FlashBaseFinetuning + + +class Seq2SeqFreezeEmbeddings(FlashBaseFinetuning): + """ + Freezes the embedding layers during Seq2Seq training. + """ + + def __init__(self, model_type: str, train_bn: bool = True): + super().__init__("", train_bn) + self.model_type = model_type + + def freeze_before_training(self, pl_module: pl.LightningModule) -> None: + is_t5 = self.model_type in ["t5", "mt5"] + model = pl_module.model if is_t5 else pl_module.model.model + self.freeze(module=model.shared, train_bn=self.train_bn) + for layer in (model.encoder, model.decoder): + self.freeze(layer.embed_tokens) + if not is_t5: + self.freeze(layer.embed_positions) diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py new file mode 100644 index 0000000000..5c6f6e9c48 --- /dev/null +++ b/flash/text/seq2seq/core/model.py @@ -0,0 +1,129 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
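+
+# Rough usage sketch for the task defined in this file (the ``datamodule`` is
+# assumed to come from ``Seq2SeqData.from_files``; names are illustrative):
+#
+#     import flash
+#     from flash.text import Seq2SeqTask
+#
+#     model = Seq2SeqTask(backbone="t5-small")
+#     trainer = flash.Trainer(max_epochs=1)
+#     # configure_finetune_callback() below returns Seq2SeqFreezeEmbeddings,
+#     # so the embedding layers stay frozen during finetuning by default
+#     trainer.finetune(model, datamodule=datamodule)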
+import os
+import warnings
+from typing import Any, Callable, List, Mapping, Optional, Sequence, Type, Union
+
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.utilities import rank_zero_info
+from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerBase
+
+from flash.core import Task
+from flash.core.finetuning import FlashBaseFinetuning
+from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings
+
+
+def _pad_tensors_to_max_len(model_cfg, tensor, max_length):
+    pad_token_id = model_cfg.pad_token_id if model_cfg.pad_token_id is not None else model_cfg.eos_token_id
+    if pad_token_id is None:
+        raise ValueError(
+            f"Make sure that either `config.pad_token_id` or `config.eos_token_id` "
+            f"is defined if tensor has to be padded to `max_length`={max_length}"
+        )
+
+    padded_tensor = pad_token_id * torch.ones((tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device)
+    padded_tensor[:, :tensor.shape[-1]] = tensor
+    return padded_tensor
+
+
+class Seq2SeqTask(Task):
+    """General Task for Sequence2Sequence.
+    Args:
+        backbone: backbone model to use for the task.
+        loss_fn: Loss function for training
+        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+        metrics: Metrics to compute for training and evaluation.
+        learning_rate: Learning rate to use for training, defaults to `5e-5`
+        val_target_max_length: Maximum length of targets in validation. Defaults to `128`
+        num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
+    """
+
+    def __init__(
+        self,
+        backbone: str = 't5-small',
+        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
+        learning_rate: float = 5e-5,
+        val_target_max_length: Optional[int] = None,
+        num_beams: Optional[int] = None,
+    ):
+        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
+        # disable HF thousand warnings
+        warnings.simplefilter("ignore")
+        # set os environ variable for multiprocesses
+        os.environ["PYTHONWARNINGS"] = "ignore"
+        super().__init__(loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate)
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(backbone)
+        self.val_target_max_length = val_target_max_length
+        self.num_beams = num_beams
+        self._initialize_model_specific_parameters()
+
+    def forward(self, x: Any) -> Any:
+        max_length = self.val_target_max_length if self.val_target_max_length else self.model.config.max_length
+        num_beams = self.num_beams if self.num_beams else self.model.config.num_beams
+        generated_tokens = self.model.generate(
+            input_ids=x['input_ids'], attention_mask=x['attention_mask'], max_length=max_length, num_beams=num_beams
+        )
+        # in case the batch is shorter than max length, the output should be padded
+        if generated_tokens.shape[-1] < max_length:
+            generated_tokens = _pad_tensors_to_max_len(
+                model_cfg=self.model.config, tensor=generated_tokens, max_length=max_length
+            )
+        return generated_tokens
+
+    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
+        outputs = self.model(**batch)
+        loss = outputs[0]
+        self.log("train_loss", loss)
+        return loss
+
+    def common_step(self, prefix: str, batch: Any) -> None:
+        generated_tokens = self.predict(batch, skip_collate_fn=True)
+        self.compute_metrics(generated_tokens, batch, prefix)
+
+    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
+        self.common_step("val", batch)
+
+    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
+        self.common_step("test", batch)
+
+    def compute_metrics(self, generated_tokens, batch, prefix):
+        pass
+
+    @property
+    def task(self) -> Optional[str]:
+        """
+        Override to define AutoConfig task specific parameters stored within the model.
+        """
+        pass
+
+    def _initialize_model_specific_parameters(self):
+        task_specific_params = self.model.config.task_specific_params
+
+        if task_specific_params is not None:
+            pars = task_specific_params.get(self.task, {})
+            rank_zero_info(f"Overriding model parameters for {self.task} as defined within the model:\n {pars}")
+            self.model.config.update(pars)
+
+    @property
+    def tokenizer(self) -> PreTrainedTokenizerBase:
+        return self.data_pipeline.tokenizer
+
+    def tokenize_labels(self, labels: torch.Tensor) -> List[str]:
+        label_str = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+        return [str.strip(s) for s in label_str]
+
+    def configure_finetune_callback(self) -> List[FlashBaseFinetuning]:
+        return [Seq2SeqFreezeEmbeddings(self.model.config.model_type, train_bn=True)]
diff --git a/flash/text/seq2seq/summarization/__init__.py b/flash/text/seq2seq/summarization/__init__.py
new file mode 100644
index 0000000000..a9680d337f
--- /dev/null
+++ b/flash/text/seq2seq/summarization/__init__.py
@@ -0,0 +1,2 @@
+from flash.text.seq2seq.summarization.data import SummarizationData
+from flash.text.seq2seq.summarization.model import SummarizationTask
diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py
new file mode 100644
index 0000000000..20e0eb2ba2
--- /dev/null
+++ b/flash/text/seq2seq/summarization/data.py
@@ -0,0 +1,132 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+from transformers import AutoTokenizer
+
+from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqDataPipeline
+
+
+class SummarizationData(Seq2SeqData):
+
+    @staticmethod
+    def default_pipeline():
+        return Seq2SeqDataPipeline(
+            AutoTokenizer.from_pretrained("t5-small", use_fast=True),
+            input="input",
+        )
+
+    @classmethod
+    def from_files(
+        cls,
+        train_file: str,
+        input: str = 'input',
+        target: Optional[str] = None,
+        filetype: str = "csv",
+        backbone: str = "t5-small",
+        valid_file: Optional[str] = None,
+        test_file: Optional[str] = None,
+        max_source_length: int = 512,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'max_length',
+        batch_size: int = 16,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a SummarizationData object from files.
+
+        Args:
+            train_file: Path to training data.
+            input: The field storing the source text.
+            target: The field storing the target text.
+            filetype: .csv or .json
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            valid_file: Path to validation data.
+            test_file: Path to test data.
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the maximum length.
+            batch_size: the batch size to use for parallel loading. Defaults to 16.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            SummarizationData: The constructed data module.
+
+        Examples::
+
+            datamodule = SummarizationData.from_files(train_file="train_data.csv",
+                                                      input="input",
+                                                      target="target",
+                                                      backbone="t5-small")
+
+        """
+        return super().from_files(
+            train_file=train_file,
+            valid_file=valid_file,
+            test_file=test_file,
+            input=input,
+            target=target,
+            backbone=backbone,
+            filetype=filetype,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding,
+            batch_size=batch_size,
+            num_workers=num_workers
+        )
+
+    @classmethod
+    def from_file(
+        cls,
+        predict_file: str,
+        input: str = 'src_text',
+        target: Optional[str] = None,
+        backbone: str = "t5-small",
+        filetype: str = "csv",
+        max_source_length: int = 512,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'longest',
+        batch_size: int = 16,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a SummarizationData object from a prediction input file.
+
+        Args:
+            predict_file: Path to prediction input file.
+            input: The field storing the source text.
+            target: The field storing the target text.
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            filetype: .csv or .json
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the longest sequence in the batch.
+            batch_size: the batch size to use for parallel loading. Defaults to 16.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            SummarizationData: The constructed data module.
+
+        """
+        return super().from_file(
+            predict_file=predict_file,
+            input=input,
+            target=target,
+            backbone=backbone,
+            filetype=filetype,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding,
+            batch_size=batch_size,
+            num_workers=num_workers
+        )
diff --git a/flash/text/seq2seq/summarization/metric.py b/flash/text/seq2seq/summarization/metric.py
new file mode 100644
index 0000000000..58581ac0c3
--- /dev/null
+++ b/flash/text/seq2seq/summarization/metric.py
@@ -0,0 +1,109 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
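+
+# Usage sketch for the metric defined below (the example strings are
+# illustrative only):
+#
+#     rouge = RougeMetric(rouge_newline_sep=True, use_stemmer=True)
+#     rouge.update(["the cat sat on the mat"], ["a cat was on the mat"])
+#     scores = rouge.compute()  # {"rouge1_precision": ..., "rouge1_recall": ..., ...}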
+from typing import Dict, List, Tuple + +import numpy as np +import torch +from pytorch_lightning.metrics import Metric +from rouge_score import rouge_scorer, scoring +from rouge_score.scoring import AggregateScore, Score + +from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence + + +class RougeMetric(Metric): + """ + Metric used for automatic summarization. https://www.aclweb.org/anthology/W04-1013/ + """ + + def __init__( + self, + rouge_newline_sep: bool, + use_stemmer: bool, + rouge_keys: Tuple[str] = ("rouge1", "rouge2", "rougeL", "rougeLsum"), + ): + super().__init__() + self.rouge_newline_sep = rouge_newline_sep + self.rouge_keys = rouge_keys + self.use_stemmer = use_stemmer + self.aggregator = RougeBatchAggregator() + self.scorer = rouge_scorer.RougeScorer(rouge_keys, use_stemmer=self.use_stemmer) + + for key in rouge_keys: + self.add_state(key, []) + + def update(self, pred_lns: List[str], tgt_lns: List[str]): + for pred, tgt in zip(pred_lns, tgt_lns): + # rougeLsum expects "\n" separated sentences within a summary + if self.rouge_newline_sep: + pred = add_newline_to_end_of_each_sentence(pred) + tgt = add_newline_to_end_of_each_sentence(tgt) + results = self.scorer.score(pred, tgt) + for key, score in results.items(): + score = torch.tensor([score.precision, score.recall, score.fmeasure]) + getattr(self, key).append(score) + + def compute(self) -> Dict[str, float]: + scores = {key: getattr(self, key) for key in self.rouge_keys} + self.aggregator.add_scores(scores) + result = self.aggregator.aggregate() + return format_rouge_results(result) + + def __hash__(self): + # override to hash list objects. + # this is a bug in the upstream pytorch release. + hash_vals = [self.__class__.__name__] + + for key in self._defaults.keys(): + value = getattr(self, key) + if isinstance(value, list): + value = tuple(value) + hash_vals.append(value) + + return hash(tuple(hash_vals)) + + +class RougeBatchAggregator(scoring.BootstrapAggregator): + """ + Aggregates rouge scores and provides confidence intervals. + """ + + def aggregate(self): + """ + Override function to wrap the final results in `Score` objects. + This is due to the scores being replaced with a list of torch tensors. + """ + result = {} + for score_type, scores in self._scores.items(): + # Stack scores into a 2-d matrix of (sample, measure). + score_matrix = np.vstack(tuple(scores)) + # Percentiles are returned as (interval, measure). + percentiles = self._bootstrap_resample(score_matrix) + # Extract the three intervals (low, mid, high). + intervals = tuple((Score(*percentiles[j, :]) for j in range(3))) + result[score_type] = AggregateScore(low=intervals[0], mid=intervals[1], high=intervals[2]) + return result + + def add_scores(self, scores): + self._scores = scores + + +def format_rouge_results(result: Dict[str, AggregateScore], decimal_places: int = 4) -> Dict[str, float]: + flattened_result = {} + for rouge_key, rouge_aggregate_score in result.items(): + for stat in ["precision", "recall", "fmeasure"]: + mid = rouge_aggregate_score.mid + score = round(getattr(mid, stat), decimal_places) + flattened_result[f"{rouge_key}_{stat}"] = score + return flattened_result diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py new file mode 100644 index 0000000000..ed0c381a71 --- /dev/null +++ b/flash/text/seq2seq/summarization/model.py @@ -0,0 +1,59 @@ +# Copyright The PyTorch Lightning team. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Mapping, Optional, Sequence, Type, Union
+
+import pytorch_lightning as pl
+import torch
+
+from flash.text.seq2seq.core.model import Seq2SeqTask
+from flash.text.seq2seq.summarization.metric import RougeMetric
+
+
+class SummarizationTask(Seq2SeqTask):
+
+    def __init__(
+        self,
+        backbone: str = "t5-small",
+        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
+        learning_rate: float = 5e-5,
+        val_target_max_length: Optional[int] = None,
+        num_beams: Optional[int] = 4,
+        use_stemmer: bool = True,
+        rouge_newline_sep: bool = True
+    ):
+        self.save_hyperparameters()
+        super().__init__(
+            backbone=backbone,
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            metrics=metrics,
+            learning_rate=learning_rate,
+            val_target_max_length=val_target_max_length,
+            num_beams=num_beams
+        )
+        self.rouge = RougeMetric(
+            rouge_newline_sep=rouge_newline_sep,
+            use_stemmer=use_stemmer,
+        )
+
+    @property
+    def task(self) -> str:
+        return "summarization"
+
+    def compute_metrics(self, generated_tokens, batch, prefix):
+        tgt_lns = self.tokenize_labels(batch["labels"])
+        result = self.rouge(generated_tokens, tgt_lns)
+        self.log_dict(result, on_step=False, on_epoch=True)
diff --git a/flash/text/seq2seq/summarization/utils.py b/flash/text/seq2seq/summarization/utils.py
new file mode 100644
index 0000000000..5af6ce679a
--- /dev/null
+++ b/flash/text/seq2seq/summarization/utils.py
@@ -0,0 +1,30 @@
+# Copyright 2020 The PyTorch Lightning team and The HuggingFace Team. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from filelock import FileLock
+from pytorch_lightning.utilities import _module_available
+
+nltk = None
+if _module_available('nltk'):
+    import nltk
+    with FileLock(".lock") as lock:
+        nltk.download("punkt", quiet=True)
+
+
+def add_newline_to_end_of_each_sentence(x: str) -> str:
+    """This was added to get rougeLsum scores matching published rougeL scores for BART and PEGASUS."""
+    x = re.sub("<n>", "", x)  # remove the pegasus newline token "<n>"
+    assert nltk, "nltk must be installed to separate newlines between sentences. (pip install nltk)"
+    return "\n".join(nltk.sent_tokenize(x))
diff --git a/flash/text/seq2seq/translation/__init__.py b/flash/text/seq2seq/translation/__init__.py
new file mode 100644
index 0000000000..d726952a05
--- /dev/null
+++ b/flash/text/seq2seq/translation/__init__.py
@@ -0,0 +1,2 @@
+from flash.text.seq2seq.translation.data import TranslationData
+from flash.text.seq2seq.translation.model import TranslationTask
diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py
new file mode 100644
index 0000000000..afaf9b5cfb
--- /dev/null
+++ b/flash/text/seq2seq/translation/data.py
@@ -0,0 +1,134 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, Union
+
+from transformers import AutoTokenizer
+
+from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqDataPipeline
+
+
+class TranslationData(Seq2SeqData):
+    """Data module for Translation tasks."""
+
+    @staticmethod
+    def default_pipeline():
+        return Seq2SeqDataPipeline(
+            AutoTokenizer.from_pretrained("facebook/mbart-large-en-ro", use_fast=True),
+            input="input",
+        )
+
+    @classmethod
+    def from_files(
+        cls,
+        train_file: str,
+        input: str = 'input',
+        target: Optional[str] = None,
+        filetype: str = "csv",
+        backbone: str = "facebook/mbart-large-en-ro",
+        valid_file: Optional[str] = None,
+        test_file: Optional[str] = None,
+        max_source_length: int = 128,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'max_length',
+        batch_size: int = 8,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a TranslationData object from files.
+
+        Args:
+            train_file: Path to training data.
+            input: The field storing the source translation text.
+            target: The field storing the target translation text.
+            filetype: .csv or .json
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            valid_file: Path to validation data.
+            test_file: Path to test data.
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the maximum length.
+            batch_size: the batch size to use for parallel loading. Defaults to 8.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            TranslationData: The constructed data module.
+
+        Examples::
+
+            datamodule = TranslationData.from_files(train_file="train_data.csv",
+                                                    input="input",
+                                                    target="target",
+                                                    backbone="facebook/mbart-large-en-ro")
+
+        """
+        return super().from_files(
+            train_file=train_file,
+            valid_file=valid_file,
+            test_file=test_file,
+            input=input,
+            target=target,
+            backbone=backbone,
+            filetype=filetype,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding,
+            batch_size=batch_size,
+            num_workers=num_workers
+        )
+
+    @classmethod
+    def from_file(
+        cls,
+        predict_file: str,
+        input: str = 'input',
+        target: Optional[str] = None,
+        backbone: str = "facebook/mbart-large-en-ro",
+        filetype: str = "csv",
+        max_source_length: int = 128,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'longest',
+        batch_size: int = 8,
+        num_workers: Optional[int] = None,
+    ):
+        """Creates a TranslationData object from a prediction input file.
+
+        Args:
+            predict_file: Path to prediction input file.
+            input: The field storing the source translation text.
+            target: The field storing the target translation text.
+            backbone: tokenizer to use, can use any HuggingFace tokenizer.
+            filetype: .csv or .json
+            max_source_length: Maximum length of the source text. Any text longer will be truncated.
+            max_target_length: Maximum length of the target text. Any text longer will be truncated.
+            padding: Padding strategy for batches. Default is to pad to the longest sequence in the batch.
+            batch_size: the batch size to use for parallel loading. Defaults to 8.
+            num_workers: The number of workers to use for parallelized loading.
+                Defaults to None which equals the number of available CPU threads.
+
+        Returns:
+            TranslationData: The constructed data module.
+
+        """
+        return super().from_file(
+            predict_file=predict_file,
+            input=input,
+            target=target,
+            backbone=backbone,
+            filetype=filetype,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding,
+            batch_size=batch_size,
+            num_workers=num_workers
+        )
diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py
new file mode 100644
index 0000000000..5b29cdc90b
--- /dev/null
+++ b/flash/text/seq2seq/translation/metric.py
@@ -0,0 +1,122 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
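+
+# The BLEUScore metric below accumulates clipped n-gram matches (numerator) and
+# candidate n-gram counts (denominator), then computes
+#     BLEU = BP * exp(sum_{n=1..N} (1/N) * log p_n)
+# where p_n = numerator_n / denominator_n and the brevity penalty
+# BP = 1 if c > r else exp(1 - r/c). A quick sketch of its use:
+#
+#     bleu = BLEUScore(n_gram=4)
+#     score = bleu(['the cat is on the mat'.split()],
+#                  [['a cat is on the mat'.split()]])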
+# referenced from
+# Library Name: torchtext
+# Authors: torchtext authors and @sluks
+# Date: 2020-07-18
+# Link: https://pytorch.org/text/_modules/torchtext/data/metrics.html#bleu_score
+from collections import Counter
+from typing import List
+
+import torch
+from pytorch_lightning.metrics import Metric
+
+
+def _count_ngram(ngram_input_list: List[str], n_gram: int) -> Counter:
+    """
+    Count how many times each n-gram appears in a given list of tokens.
+
+    Args:
+        ngram_input_list: A list of tokens from a translated or reference text
+        n_gram: Maximum n-gram order, from 1 to 4
+
+    Returns:
+        ngram_counter: a collections.Counter keyed by n-gram tuples
+    """
+
+    ngram_counter = Counter()
+
+    for i in range(1, n_gram + 1):
+        for j in range(len(ngram_input_list) - i + 1):
+            ngram_key = tuple(ngram_input_list[j:(i + j)])
+            ngram_counter[ngram_key] += 1
+
+    return ngram_counter
+
+
+class BLEUScore(Metric):
+    """
+    Calculate the BLEU score of machine translated text against one or more references.
+
+    Example:
+        >>> translate_corpus = ['the cat is on the mat'.split()]
+        >>> reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]
+        >>> metric = BLEUScore()
+        >>> metric(translate_corpus, reference_corpus)
+        tensor(0.7598)
+    """

+    def __init__(self, n_gram: int = 4, smooth: bool = False):
+        """
+        Args:
+            n_gram: Maximum n-gram order to use, from 1 to 4 (default: 4)
+            smooth: Whether to apply smoothing (see Lin et al., 2004)
+        """
+        super().__init__()
+        self.n_gram = n_gram
+        self.smooth = smooth
+
+        self.add_state("c", torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
+        self.add_state("r", torch.tensor(0, dtype=torch.float), dist_reduce_fx="sum")
+        self.add_state("numerator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
+        self.add_state("denominator", torch.zeros(self.n_gram), dist_reduce_fx="sum")
+
+    def compute(self):
+        trans_len = self.c.clone().detach()
+        ref_len = self.r.clone().detach()
+
+        if min(self.numerator) == 0.0:
+            return torch.tensor(0.0, device=self.r.device)
+
+        if self.smooth:
+            precision_scores = torch.add(self.numerator, torch.ones(
+                self.n_gram
+            )) / torch.add(self.denominator, torch.ones(self.n_gram))
+        else:
+            precision_scores = self.numerator / self.denominator
+
+        log_precision_scores = torch.tensor([1.0 / self.n_gram] * self.n_gram,
+                                            device=self.r.device) * torch.log(precision_scores)
+        geometric_mean = torch.exp(torch.sum(log_precision_scores))
+        brevity_penalty = (
+            torch.tensor(1.0, device=self.r.device) if self.c > self.r else torch.exp(1 - (ref_len / trans_len))
+        )
+        bleu = brevity_penalty * geometric_mean
+        return bleu
+
+    def update(self, translate_corpus, reference_corpus) -> None:
+        """
+        Update the metric states from a batch of translations and references.
+
+        Args:
+            translate_corpus: An iterable of machine translated corpus
+            reference_corpus: An iterable of iterables of reference corpus
+        """
+        for (translation, references) in zip(translate_corpus, reference_corpus):
+            self.c += len(translation)
+            ref_len_list = [len(ref) for ref in references]
+            ref_len_diff = [abs(len(translation) - x) for x in ref_len_list]
+            self.r += ref_len_list[ref_len_diff.index(min(ref_len_diff))]
+            translation_counter = _count_ngram(translation, self.n_gram)
+            reference_counter = Counter()
+
+            for ref in references:
+                reference_counter |= _count_ngram(ref, self.n_gram)
+
+            ngram_counter_clip = translation_counter & reference_counter
+
+            for counter_clip in ngram_counter_clip:
+                self.numerator[len(counter_clip) - 1] += ngram_counter_clip[counter_clip]
+
+            for counter in translation_counter:
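+                # Editor's note: `denominator` accumulates the raw (unclipped) candidate
+                # n-gram counts per order; dividing the clipped counts gathered above by
+                # these totals in compute() yields the modified n-gram precisions.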
+                self.denominator[len(counter) - 1] += translation_counter[counter]
diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py
new file mode 100644
index 0000000000..4ea25c8afb
--- /dev/null
+++ b/flash/text/seq2seq/translation/model.py
@@ -0,0 +1,73 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, Mapping, Optional, Sequence, Type, Union
+
+import pytorch_lightning as pl
+import torch
+
+from flash.text.seq2seq.core.model import Seq2SeqTask
+from flash.text.seq2seq.translation.metric import BLEUScore
+
+
+class TranslationTask(Seq2SeqTask):
+    """Task for Sequence2Sequence Translation.
+
+    Args:
+        backbone: backbone model to use for the task.
+        loss_fn: Loss function for training.
+        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
+        metrics: Metrics to compute for training and evaluation.
+        learning_rate: Learning rate to use for training, defaults to `3e-4`.
+        val_target_max_length: Maximum length of targets in validation. Defaults to `128`.
+        num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`.
+        n_gram: Maximum n-gram order to use in the BLEU metric calculation. Defaults to `4`.
+        smooth: Apply smoothing in BLEU calculation. Defaults to `False`.
+    """
+
+    def __init__(
+        self,
+        backbone: str = "facebook/mbart-large-en-ro",
+        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
+        optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
+        metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
+        learning_rate: float = 3e-4,
+        val_target_max_length: Optional[int] = 128,
+        num_beams: Optional[int] = 4,
+        n_gram: int = 4,
+        smooth: bool = False,
+    ):
+        self.save_hyperparameters()
+        super().__init__(
+            backbone=backbone,
+            loss_fn=loss_fn,
+            optimizer=optimizer,
+            metrics=metrics,
+            learning_rate=learning_rate,
+            val_target_max_length=val_target_max_length,
+            num_beams=num_beams,
+        )
+        self.bleu = BLEUScore(
+            n_gram=n_gram,
+            smooth=smooth,
+        )
+
+    @property
+    def task(self) -> str:
+        return "translation"
+
+    def compute_metrics(self, generated_tokens, batch, prefix):
+        tgt_lns = self.tokenize_labels(batch["labels"])
+        # wrap targets in list as score expects a list of potential references
+        tgt_lns = [[reference] for reference in tgt_lns]
+        result = self.bleu(generated_tokens, tgt_lns)
+        self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True)
diff --git a/flash/vision/classification/model.py b/flash/vision/classification/model.py
index 66c38fa08c..ef49b18f04 100644
--- a/flash/vision/classification/model.py
+++ b/flash/vision/classification/model.py
@@ -16,6 +16,7 @@
 import torch
 import torchvision
 from pytorch_lightning.metrics import Accuracy
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch import nn
 from torch.nn import functional as F
@@ -51,6 +52,7 @@ def __init__(
         self,
         num_classes,
         backbone="resnet18",
+        num_features: int = None,
         pretrained=True,
         loss_fn: Callable = F.cross_entropy,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.SGD,
@@ -68,7 +70,7 @@ def __init__(
         self.save_hyperparameters()
 
         if backbone not in _backbones:
-            raise NotImplementedError(f"Backbone {backbone} is not yet supported")
+            raise MisconfigurationException(f"Backbone {backbone} is not yet supported")
 
         backbone_fn, split, num_feats = _backbones[backbone]
         backbone = backbone_fn(pretrained=pretrained)
diff --git a/flash_examples/finetuning/summarization.py b/flash_examples/finetuning/summarization.py
new file mode 100644
index 0000000000..806ce8bfcf
--- /dev/null
+++ b/flash_examples/finetuning/summarization.py
@@ -0,0 +1,31 @@
+import flash
+from flash import download_data
+from flash.text import SummarizationData, SummarizationTask
+
+if __name__ == "__main__":
+    # 1. Download the data
+    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/')
+
+    # 2. Load the data
+    datamodule = SummarizationData.from_files(
+        train_file="data/xsum/train.csv",
+        valid_file="data/xsum/valid.csv",
+        test_file="data/xsum/test.csv",
+        input="input",
+        target="target"
+    )
+
+    # 3. Build the model
+    model = SummarizationTask()
+
+    # 4. Create the trainer. Run once on data
+    trainer = flash.Trainer(max_epochs=1)
+
+    # 5. Fine-tune the model
+    trainer.finetune(model, datamodule=datamodule)
+
+    # 6. Test model
+    trainer.test()
+
+    # 7. Save it!
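+    # Editor's note: the checkpoint also stores the hyperparameters captured by
+    # save_hyperparameters(), so the task can later be restored with
+    # SummarizationTask.load_from_checkpoint (see flash_examples/predict/summarize.py).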
+ trainer.save_checkpoint("summarization_model_xsum.pt") diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py new file mode 100644 index 0000000000..e7f1debce3 --- /dev/null +++ b/flash_examples/finetuning/translation.py @@ -0,0 +1,31 @@ +import flash +from flash import download_data +from flash.text import TranslationData, TranslationTask + +if __name__ == "__main__": + # 1. Download the data + download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') + + # 2. Load the data + datamodule = TranslationData.from_files( + train_file="data/wmt_en_ro/train.csv", + valid_file="data/wmt_en_ro/valid.csv", + test_file="data/wmt_en_ro/test.csv", + input="input", + target="target", + ) + + # 3. Build the model + model = TranslationTask() + + # 4. Create the trainer. Run once on data + trainer = flash.Trainer(max_epochs=1, precision=16, gpus=1) + + # 5. Fine-tune the model + trainer.finetune(model, datamodule=datamodule) + + # 6. Test model + trainer.test() + + # 7. Save it! + trainer.save_checkpoint("translation_model_en_ro.pt") diff --git a/flash_examples/predict/summarize.py b/flash_examples/predict/summarize.py new file mode 100644 index 0000000000..1cd5e68e4e --- /dev/null +++ b/flash_examples/predict/summarize.py @@ -0,0 +1,44 @@ +from pytorch_lightning import Trainer + +from flash.core.data import download_data +from flash.text import SummarizationData, SummarizationTask + +if __name__ == "__main__": + # 1. Download the data + download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", 'data/') + + # 2. Load the model from a checkpoint + model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt") + + # 2a. Summarize an article! + predictions = model.predict([ + """ + Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local + people within the community.The couple were surrounded by shoppers as they walked along Electric Avenue. + They came to Brixton to see work which has started to revitalise the borough. + It was Charles' first visit to the area since 1996, when he was accompanied by the former + South African president Nelson Mandela.Greengrocer Derek Chong, who has run a stall on Electric Avenue + for 20 years, said Camilla had been ""nice and pleasant"" when she purchased the fruit. + ""She asked me what was nice, what would I recommend, and I said we've got some nice mangoes. + She asked me were they ripe and I said yes - they're from the Dominican Republic."" + Mr Chong is one of 170 local retailers who accept the Brixton Pound. + Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market + or in participating shops. + During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children + nearby on an estate off Coldharbour Lane. Mr West said: + ""He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."" + He added: ""I told him I was working with young kids and he said, 'Keep up all the good work.'"" + Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust. + The trust hopes to restore and refurbish the building, + where once Jimi Hendrix and The Clash played, as a new community and business centre." + """ + ]) + print(predictions) + + # 2b. Or generate summaries from a sheet file! 
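+    # Editor's note: from_file() builds a predict-only datamodule, so no `target`
+    # column is required; see TranslationData.from_file earlier in this patch for
+    # the analogous argument defaults.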
+ datamodule = SummarizationData.from_file( + predict_file="data/xsum/predict.csv", + input="input", + ) + predictions = Trainer().predict(model, datamodule=datamodule) + print(predictions) diff --git a/flash_examples/predict/translate.py b/flash_examples/predict/translate.py new file mode 100644 index 0000000000..4003b689d0 --- /dev/null +++ b/flash_examples/predict/translate.py @@ -0,0 +1,27 @@ +from pytorch_lightning import Trainer + +from flash import download_data +from flash.text import TranslationData, TranslationTask + +if __name__ == "__main__": + + # 1. Download the data + download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", 'data/') + + # 2. Load the model from a checkpoint + model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt") + + # 2a. Translate a few sentences! + predictions = model.predict([ + "BBC News went to meet one of the project's first graduates.", + "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.", + ]) + print(predictions) + + # 2b. Or generate translations from a sheet file! + datamodule = TranslationData.from_file( + predict_file="data/wmt_en_ro/predict.csv", + input="input", + ) + predictions = Trainer().predict(model, datamodule=datamodule) + print(predictions) diff --git a/flash_notebooks/generic_task.ipynb b/flash_notebooks/generic_task.ipynb index f0133b429c..8721cca796 100644 --- a/flash_notebooks/generic_task.ipynb +++ b/flash_notebooks/generic_task.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "unlike-price", + "id": "outstanding-knight", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "requested-hostel", + "id": "british-auckland", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by creating a ClassificationTask with a custom Convolutional Model and train it on [MNIST Dataset](http://yann.lecun.com/exdb/mnist/)\n", @@ -26,7 +26,7 @@ }, { "cell_type": "markdown", - "id": "historic-cowboy", + "id": "chicken-bradford", "metadata": {}, "source": [ "# Training" @@ -35,7 +35,7 @@ { "cell_type": "code", "execution_count": null, - "id": "intermediate-rebecca", + "id": "impaired-trick", "metadata": {}, "outputs": [], "source": [ @@ -45,18 +45,10 @@ }, { "cell_type": "code", - "execution_count": 1, - "id": "proof-plenty", + "execution_count": null, + "id": "technological-certification", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "PyTorch version 1.7.1 available.\n" - ] - } - ], + "outputs": [], "source": [ "import pytorch_lightning as pl\n", "from torch import nn, optim\n", @@ -68,7 +60,7 @@ }, { "cell_type": "markdown", - "id": "smart-factor", + "id": "several-board", "metadata": {}, "source": [ "### 1. Load a basic backbone" @@ -76,8 +68,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "entitled-opera", + "execution_count": null, + "id": "upper-quest", "metadata": {}, "outputs": [], "source": [ @@ -91,7 +83,7 @@ }, { "cell_type": "markdown", - "id": "restricted-tooth", + "id": "faced-captain", "metadata": {}, "source": [ "### 2. 
Load a dataset" @@ -99,122 +91,17 @@ }, { "cell_type": "code", - "execution_count": 3, - "id": "polish-duncan", + "execution_count": null, + "id": "welcome-hammer", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cef56b96ec38400b8d8b2acadcd20f58", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7b13a282cce148f9ac5946aa949935ec", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2b79c71d03464940a48de53bda953f82", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6d1a06fc73f34a988a333f386d4fed67", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw\n", - "Processing...\n", - "Done!\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:141.)\n", - " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" - ] - } - ], + "outputs": [], "source": [ "dataset = datasets.MNIST('./data', download=True, transform=transforms.ToTensor())" ] }, { "cell_type": "markdown", - "id": "starting-edmonton", + "id": "banned-gardening", "metadata": {}, "source": [ "### 3. 
Split the data randomly" @@ -222,8 +109,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "id": "indonesian-arlington", + "execution_count": null, + "id": "southwest-muscle", "metadata": {}, "outputs": [], "source": [ @@ -232,7 +119,7 @@ }, { "cell_type": "markdown", - "id": "configured-bones", + "id": "formal-carnival", "metadata": {}, "source": [ "### 4. Create the model" @@ -240,8 +127,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "fleet-breast", + "execution_count": null, + "id": "essential-community", "metadata": {}, "outputs": [], "source": [ @@ -250,7 +137,7 @@ }, { "cell_type": "markdown", - "id": "vulnerable-shirt", + "id": "controlling-combination", "metadata": {}, "source": [ "### 5. Create the trainer" @@ -258,19 +145,10 @@ }, { "cell_type": "code", - "execution_count": 6, - "id": "assigned-bahamas", + "execution_count": null, + "id": "altered-wealth", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "GPU available: False, used: False\n", - "TPU available: None, using: 0 TPU cores\n" - ] - } - ], + "outputs": [], "source": [ "trainer = pl.Trainer(\n", " max_epochs=10,\n", @@ -281,7 +159,7 @@ }, { "cell_type": "markdown", - "id": "unavailable-sodium", + "id": "worldwide-fashion", "metadata": {}, "source": [ "### 6. Train the model" @@ -289,222 +167,17 @@ }, { "cell_type": "code", - "execution_count": 7, - "id": "grave-complaint", + "execution_count": null, + "id": "according-ebony", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - " | Name | Type | Params\n", - "---------------------------------------\n", - "0 | model | Sequential | 101 K \n", - "1 | metrics | ModuleDict | 0 \n", - "---------------------------------------\n", - "101 K Trainable params\n", - "0 Non-trainable params\n", - "101 K Total params\n", - "0.407 Total estimated model params size (MB)\n", - "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validation sanity check: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, train dataloader, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "95bfb98e9b7b4b5ca6d8cb94f3825e15", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Training: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Validating: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "1" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "trainer.fit(classifier, DataLoader(train), DataLoader(val))" ] }, { "cell_type": "markdown", - "id": "excellent-detail", + "id": "spread-chambers", "metadata": {}, "source": [ "### 7. 
Test the model" @@ -512,50 +185,17 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "sized-string", + "execution_count": null, + "id": "molecular-retention", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/thomas/Documents/GitHub/lightning-flash/.venv/lib/python3.6/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", - " warnings.warn(*args, **kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0c7202f53c22480a8ec2b9bbdbdde631", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Testing: 0it [00:00, ?it/s]" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--------------------------------------------------------------------------------\n", - "DATALOADER:0 TEST RESULTS\n", - "{'test_cross_entropy': 1.5057185888290405}\n", - "--------------------------------------------------------------------------------\n" - ] - } - ], + "outputs": [], "source": [ "results = trainer.test(classifier, test_dataloaders=DataLoader(test))" ] }, { "cell_type": "markdown", - "id": "endless-contrary", + "id": "charitable-night", "metadata": {}, "source": [ "# Predicting" @@ -563,28 +203,17 @@ }, { "cell_type": "code", - "execution_count": 13, - "id": "informative-arnold", + "execution_count": null, + "id": "continued-daisy", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[9, 1, 8, 5, 5, 6, 6, 3, 3, 5, 5, 3, 8, 1, 2, 7, 3, 9, 8, 1, 4, 3, 8, 0, 3]" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "classifier.predict(predict)" ] }, { "cell_type": "markdown", - "id": "dimensional-breakfast", + "id": "nominated-found", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/image_classification.ipynb b/flash_notebooks/image_classification.ipynb index ea39880573..bf46fceb63 100644 --- a/flash_notebooks/image_classification.ipynb +++ b/flash_notebooks/image_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "western-queue", + "id": "periodic-lobby", "metadata": {}, "source": [ "\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "democratic-alpha", + "id": "indonesian-rogers", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetuning/predictin with an ImageClassifier on [Hymenoptera Dataset](https://www.kaggle.com/ajayrana/hymenoptera-data) containing ants and bees images.\n", @@ -43,7 +43,7 @@ { "cell_type": "code", "execution_count": null, - "id": "parallel-integrity", + "id": "jewish-control", "metadata": {}, "outputs": [], "source": [ @@ -54,7 +54,7 @@ { "cell_type": "code", "execution_count": null, - "id": "worth-wealth", + "id": "technical-story", "metadata": {}, "outputs": [], "source": [ @@ -65,7 +65,7 @@ }, { "cell_type": "markdown", - "id": "frequent-memorial", + "id": "patient-bandwidth", "metadata": {}, "source": [ "## 1. 
Download data\n", @@ -75,7 +75,7 @@ { "cell_type": "code", "execution_count": null, - "id": "planned-greene", + "id": "informative-handle", "metadata": {}, "outputs": [], "source": [ @@ -84,7 +84,7 @@ }, { "cell_type": "markdown", - "id": "changed-perry", + "id": "cubic-dialogue", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -107,7 +107,7 @@ { "cell_type": "code", "execution_count": null, - "id": "synthetic-hamburg", + "id": "painted-morocco", "metadata": {}, "outputs": [], "source": [ @@ -120,7 +120,7 @@ }, { "cell_type": "markdown", - "id": "common-testing", + "id": "sophisticated-charity", "metadata": {}, "source": [ "### 3. Build the model\n", @@ -132,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "religious-pasta", + "id": "characteristic-victory", "metadata": {}, "outputs": [], "source": [ @@ -141,7 +141,7 @@ }, { "cell_type": "markdown", - "id": "accurate-thread", + "id": "clear-panic", "metadata": {}, "source": [ "### 4. Create the trainer. Run once on data\n", @@ -158,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "rural-silly", + "id": "floral-montana", "metadata": {}, "outputs": [], "source": [ @@ -167,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "trained-unemployment", + "id": "coupled-detection", "metadata": {}, "source": [ "### 5. Finetune the model" @@ -176,7 +176,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bound-printer", + "id": "disturbed-williams", "metadata": {}, "outputs": [], "source": [ @@ -185,7 +185,7 @@ }, { "cell_type": "markdown", - "id": "european-incentive", + "id": "ruled-telescope", "metadata": {}, "source": [ "### 6. Test the model" @@ -194,7 +194,7 @@ { "cell_type": "code", "execution_count": null, - "id": "ceramic-dress", + "id": "binding-nudist", "metadata": {}, "outputs": [], "source": [ @@ -203,7 +203,7 @@ }, { "cell_type": "markdown", - "id": "cheap-residence", + "id": "protected-longer", "metadata": {}, "source": [ "### 7. Save it!" @@ -212,7 +212,7 @@ { "cell_type": "code", "execution_count": null, - "id": "micro-favor", + "id": "variable-humidity", "metadata": {}, "outputs": [], "source": [ @@ -221,7 +221,76 @@ }, { "cell_type": "markdown", - "id": "associate-demonstration", + "id": "peripheral-saying", + "metadata": {}, + "source": [ + "# Predicting" + ] + }, + { + "cell_type": "markdown", + "id": "nonprofit-level", + "metadata": {}, + "source": [ + "### 1. Load the model from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "incredible-permit", + "metadata": {}, + "outputs": [], + "source": [ + "model = ImageClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/image_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "serial-memphis", + "metadata": {}, + "source": [ + "### 2a. Predict what's on a few images! ants or bees?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "agricultural-terminal", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict([\n", + " \"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg\",\n", + " \"data/hymenoptera_data/val/bees/590318879_68cf112861.jpg\",\n", + " \"data/hymenoptera_data/val/ants/540543309_ddbb193ee5.jpg\",\n", + "])\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "oriental-plate", + "metadata": {}, + "source": [ + "### 2b. Or generate predictions with a whole folder!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "posted-cambodia", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = ImageClassificationData.from_folder(folder=\"data/hymenoptera_data/predict/\")\n", + "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "international-shannon", "metadata": {}, "source": [ "\n", diff --git a/flash_notebooks/text_classification.ipynb b/flash_notebooks/text_classification.ipynb index 411886c941..7f87d2852c 100644 --- a/flash_notebooks/text_classification.ipynb +++ b/flash_notebooks/text_classification.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "satellite-bidding", + "id": "exposed-festival", "metadata": {}, "source": [ "
\n", @@ -12,7 +12,7 @@ }, { "cell_type": "markdown", - "id": "minute-father", + "id": "compact-writer", "metadata": {}, "source": [ "In this notebook, we'll go over the basics of lightning Flash by finetunig a TextClassifier on [IMDB Dataset](https://www.imdb.com/interfaces/).\n", @@ -42,7 +42,7 @@ }, { "cell_type": "markdown", - "id": "recent-footwear", + "id": "consistent-batch", "metadata": {}, "source": [ "### Setup \n", @@ -52,7 +52,7 @@ { "cell_type": "code", "execution_count": null, - "id": "proprietary-sheriff", + "id": "round-phrase", "metadata": {}, "outputs": [], "source": [ @@ -63,7 +63,7 @@ { "cell_type": "code", "execution_count": null, - "id": "signal-doctrine", + "id": "speaking-simpson", "metadata": {}, "outputs": [], "source": [ @@ -74,7 +74,7 @@ }, { "cell_type": "markdown", - "id": "sweet-insight", + "id": "passive-murray", "metadata": {}, "source": [ "### 1. Download the data\n", @@ -84,7 +84,7 @@ { "cell_type": "code", "execution_count": null, - "id": "prescribed-circuit", + "id": "entertaining-austin", "metadata": {}, "outputs": [], "source": [ @@ -93,7 +93,7 @@ }, { "cell_type": "markdown", - "id": "appointed-syndicate", + "id": "solar-journey", "metadata": {}, "source": [ "

2. Load the data

\n", @@ -105,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "suffering-sacramento", + "id": "reverse-chinese", "metadata": {}, "outputs": [], "source": [ @@ -121,7 +121,7 @@ }, { "cell_type": "markdown", - "id": "leading-latitude", + "id": "pleasant-integration", "metadata": { "jupyter": { "outputs_hidden": true @@ -138,7 +138,7 @@ { "cell_type": "code", "execution_count": null, - "id": "educational-toner", + "id": "studied-pickup", "metadata": {}, "outputs": [], "source": [ @@ -147,7 +147,7 @@ }, { "cell_type": "markdown", - "id": "limiting-iceland", + "id": "aboriginal-difference", "metadata": { "jupyter": { "outputs_hidden": true @@ -160,7 +160,7 @@ { "cell_type": "code", "execution_count": null, - "id": "potential-hypothesis", + "id": "beneficial-venue", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "markdown", - "id": "special-costume", + "id": "decent-sodium", "metadata": { "jupyter": { "outputs_hidden": true @@ -184,7 +184,7 @@ { "cell_type": "code", "execution_count": null, - "id": "lined-phoenix", + "id": "wrapped-portuguese", "metadata": {}, "outputs": [], "source": [ @@ -193,7 +193,7 @@ }, { "cell_type": "markdown", - "id": "derived-haven", + "id": "automatic-sewing", "metadata": { "jupyter": { "outputs_hidden": true @@ -206,7 +206,7 @@ { "cell_type": "code", "execution_count": null, - "id": "bearing-israel", + "id": "widespread-proposition", "metadata": {}, "outputs": [], "source": [ @@ -215,7 +215,7 @@ }, { "cell_type": "markdown", - "id": "enormous-resort", + "id": "fifth-census", "metadata": { "jupyter": { "outputs_hidden": true @@ -228,7 +228,7 @@ { "cell_type": "code", "execution_count": null, - "id": "caroline-jewelry", + "id": "outdoor-clearing", "metadata": {}, "outputs": [], "source": [ @@ -237,7 +237,81 @@ }, { "cell_type": "markdown", - "id": "arctic-directive", + "id": "asian-elimination", + "metadata": {}, + "source": [ + "# Predicting" + ] + }, + { + "cell_type": "markdown", + "id": "serial-windows", + "metadata": {}, + "source": [ + "### 1. Load the model from a checkpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "assigned-providence", + "metadata": {}, + "outputs": [], + "source": [ + "model = TextClassifier.load_from_checkpoint(\"https://flash-weights.s3.amazonaws.com/text_classification_model.pt\")" + ] + }, + { + "cell_type": "markdown", + "id": "attended-suspect", + "metadata": {}, + "source": [ + "### 2a. Classify a few sentences! How was the movie?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "interpreted-milan", + "metadata": {}, + "outputs": [], + "source": [ + "predictions = model.predict([\n", + " \"Turgid dialogue, feeble characterization - Harvey Keitel a judge?.\",\n", + " \"The worst movie in the history of cinema.\",\n", + " \"I come from Bulgaria where it 's almost impossible to have a tornado.\"\n", + " \"Very, very afraid\"\n", + " \"This guy has done a great job with this movie!\",\n", + "])\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "apparent-thunder", + "metadata": {}, + "source": [ + "### 2b. Or generate predictions from a sheet file!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "informed-insert", + "metadata": {}, + "outputs": [], + "source": [ + "datamodule = TextClassificationData.from_file(\n", + " predict_file=\"data/imdb/predict.csv\",\n", + " input=\"review\",\n", + ")\n", + "predictions = flash.Trainer().predict(model, datamodule=datamodule)\n", + "print(predictions)" + ] + }, + { + "cell_type": "markdown", + "id": "cultural-vaccine", "metadata": {}, "source": [ "\n", diff --git a/requirements.txt b/requirements.txt index 9182f75e0c..6450cb7a62 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,7 @@ datasets==1.2.1 pandas==1.1.2 scikit-learn==0.24.0 numpy -tqdm \ No newline at end of file +tqdm +rouge-score>=0.0.4 +sentencepiece>=0.1.95 +pytorch-lightning-bolts==0.3.0 diff --git a/tests/core/test_finetuning.py b/tests/core/test_finetuning.py index d518454c22..7572f1893a 100644 --- a/tests/core/test_finetuning.py +++ b/tests/core/test_finetuning.py @@ -11,28 +11,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Callable, Mapping, Optional, Sequence, Type, Union + import pytest +import pytorch_lightning as pl import torch from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F -from flash import ClassificationTask, Trainer -from flash.core.finetuning import FlashBaseFinetuning -from tests.core.test_model import DummyDataset +from flash import Trainer +from flash.core.finetuning import NoFreeze +from flash.core.model import Task +from flash.vision.classification import ImageClassifier + + +class DummyDataset(torch.utils.data.Dataset): + + def __getitem__(self, index: int) -> Any: + return torch.rand(3, 64, 64), torch.randint(10, size=(1, )).item() + + def __len__(self) -> int: + return 100 @pytest.mark.parametrize( "strategy", ['no_freeze', 'freeze', 'freeze_unfreeze', 'unfreeze_milestones', None, 'cls', 'chocolat'] ) def test_finetuning(tmpdir: str, strategy): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) val_dl = torch.utils.data.DataLoader(DummyDataset()) - task = ClassificationTask(model, F.nll_loss) + task = ImageClassifier(10, backbone="resnet18") trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) if strategy == "cls": - strategy = FlashBaseFinetuning() + strategy = NoFreeze() if strategy == 'chocolat' or strategy is None: with pytest.raises(MisconfigurationException, match="strategy should be provided"): trainer.finetune(task, train_dl, val_dl, strategy=strategy) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index e80a56bb48..e032dfaa55 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -8,7 +8,7 @@ from flash import ClassificationTask from flash.tabular import TabularClassifier -from flash.text import TextClassifier +from flash.text import SummarizationTask, TextClassifier, TranslationTask from flash.vision import ImageClassifier # ======== Mock functions ======== @@ -99,11 +99,16 @@ def test_task_datapipeline_save(tmpdir): assert task.data_pipeline.test -@pytest.mark.parametrize(["cls", "filename"], [ - (ImageClassifier, "image_classification_model.pt"), - (TabularClassifier, "tabular_classification_model.pt"), - (TextClassifier, "text_classification_model.pt"), -]) +@pytest.mark.parametrize( + 
["cls", "filename"], + [ + (ImageClassifier, "image_classification_model.pt"), + (TabularClassifier, "tabular_classification_model.pt"), + (TextClassifier, "text_classification_model.pt"), + (SummarizationTask, "summarization_model_xsum.pt"), + # (TranslationTask, "translation_model_en_ro.pt"), todo: reduce model size or create CI friendly file size + ] +) def test_model_download(tmpdir, cls, filename): url = "https://flash-weights.s3.amazonaws.com/" with tmpdir.as_cwd(): diff --git a/tests/core/test_trainer.py b/tests/core/test_trainer.py index 1f403984fb..48cc9cd5b0 100644 --- a/tests/core/test_trainer.py +++ b/tests/core/test_trainer.py @@ -6,6 +6,7 @@ from torch.nn import functional as F from flash import ClassificationTask, Trainer +from flash.core.finetuning import Freeze, NoFreeze class DummyDataset(torch.utils.data.Dataset): @@ -51,5 +52,5 @@ def test_task_finetune(tmpdir: str): val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir) - result = trainer.finetune(task, train_dl, val_dl, strategy="freeze") + result = trainer.finetune(task, train_dl, val_dl, strategy=NoFreeze()) assert result diff --git a/tests/text/summarization/__init__.py b/tests/text/summarization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/summarization/test_data.py b/tests/text/summarization/test_data.py new file mode 100644 index 0000000000..3759634cd6 --- /dev/null +++ b/tests/text/summarization/test_data.py @@ -0,0 +1,79 @@ +import os +from pathlib import Path + +from flash.text import SummarizationData + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + +TEST_CSV_DATA = """input,target +this is a sentence one,this is a translated sentence one +this is a sentence two,this is a translated sentence two +this is a sentence three,this is a translated sentence three +""" + +TEST_JSON_DATA = """ +{"input": "this is a sentence one","target":"this is a translated sentence one"} +{"input": "this is a sentence two","target":"this is a translated sentence two"} +{"input": "this is a sentence three","target":"this is a translated sentence three"} +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA) + return path + + +def test_from_csv(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + csv_path = csv_data(tmpdir) + dm = SummarizationData.from_files( + backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +def test_from_files(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + csv_path = csv_data(tmpdir) + dm = SummarizationData.from_files( + backbone=TEST_BACKBONE, + train_file=csv_path, + valid_file=csv_path, + test_file=csv_path, + input="input", + target="target", + batch_size=1 + ) + batch = next(iter(dm.val_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + batch = next(iter(dm.test_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +def test_from_json(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + json_path = json_data(tmpdir) + dm = SummarizationData.from_files( + 
backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/summarization/test_model.py b/tests/text/summarization/test_model.py new file mode 100644 index 0000000000..a2b5efc444 --- /dev/null +++ b/tests/text/summarization/test_model.py @@ -0,0 +1,35 @@ +import os + +import torch + +from flash import Trainer +from flash.text import SummarizationTask + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + + def __getitem__(self, index): + return { + "input_ids": torch.randint(1000, size=(128, )), + "labels": torch.randint(1000, size=(128, )), + } + + def __len__(self): + return 100 + + +# ============================== + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + + +def test_init_train(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + model = SummarizationTask(TEST_BACKBONE) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) diff --git a/tests/text/translation/__init__.py b/tests/text/translation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/text/translation/test_data.py b/tests/text/translation/test_data.py new file mode 100644 index 0000000000..1bff6ba7a3 --- /dev/null +++ b/tests/text/translation/test_data.py @@ -0,0 +1,79 @@ +import os +from pathlib import Path + +from flash.text import TranslationData + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + +TEST_CSV_DATA = """input,target +this is a sentence one,this is a translated sentence one +this is a sentence two,this is a translated sentence two +this is a sentence three,this is a translated sentence three +""" + +TEST_JSON_DATA = """ +{"input": "this is a sentence one","target":"this is a translated sentence one"} +{"input": "this is a sentence two","target":"this is a translated sentence two"} +{"input": "this is a sentence three","target":"this is a translated sentence three"} +""" + + +def csv_data(tmpdir): + path = Path(tmpdir) / "data.csv" + path.write_text(TEST_CSV_DATA) + return path + + +def json_data(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA) + return path + + +def test_from_csv(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + csv_path = csv_data(tmpdir) + dm = TranslationData.from_files( + backbone=TEST_BACKBONE, train_file=csv_path, input="input", target="target", batch_size=1 + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +def test_from_files(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + csv_path = csv_data(tmpdir) + dm = TranslationData.from_files( + backbone=TEST_BACKBONE, + train_file=csv_path, + valid_file=csv_path, + test_file=csv_path, + input="input", + target="target", + batch_size=1 + ) + batch = next(iter(dm.val_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + batch = next(iter(dm.test_dataloader())) + assert "labels" in batch + assert "input_ids" in batch + + +def test_from_json(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + json_path = json_data(tmpdir) + dm = TranslationData.from_files( + 
backbone=TEST_BACKBONE, train_file=json_path, input="input", target="target", filetype="json", batch_size=1 + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/translation/test_model.py b/tests/text/translation/test_model.py new file mode 100644 index 0000000000..09ffff08b3 --- /dev/null +++ b/tests/text/translation/test_model.py @@ -0,0 +1,35 @@ +import os + +import torch + +from flash import Trainer +from flash.text import TranslationTask + +# ======== Mock functions ======== + + +class DummyDataset(torch.utils.data.Dataset): + + def __getitem__(self, index): + return { + "input_ids": torch.randint(1000, size=(128, )), + "labels": torch.randint(1000, size=(128, )), + } + + def __len__(self): + return 100 + + +# ============================== + +TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing + + +def test_init_train(tmpdir): + if os.name == "nt": + # TODO: huggingface stuff timing out on windows + return True + model = TranslationTask(TEST_BACKBONE) + train_dl = torch.utils.data.DataLoader(DummyDataset()) + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_dl) diff --git a/tests/vision/classification/test_model.py b/tests/vision/classification/test_model.py index df52a865f6..63f5be5c1e 100644 --- a/tests/vision/classification/test_model.py +++ b/tests/vision/classification/test_model.py @@ -1,5 +1,6 @@ import pytest import torch +from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash import Trainer from flash.vision import ImageClassifier @@ -37,7 +38,7 @@ def test_init_train(tmpdir, backbone): def test_non_existent_backbone(): - with pytest.raises(NotImplementedError): + with pytest.raises(MisconfigurationException): ImageClassifier(2, "i am never going to implement this lol")
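
Editor's note: as a quick sanity check while reviewing, the new ``BLEUScore`` metric can also be
exercised standalone; this sketch simply replays the example from the class docstring::

    from flash.text.seq2seq.translation.metric import BLEUScore

    # One candidate translation and two references, pre-tokenized into word lists.
    translate_corpus = ['the cat is on the mat'.split()]
    reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]]

    metric = BLEUScore(n_gram=4, smooth=False)
    print(metric(translate_corpus, reference_corpus))  # tensor(0.7598) per the docstring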