diff --git a/docs/source/index.rst b/docs/source/index.rst
index cc1a7c5ffee..b928b50cfb8 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -21,8 +21,10 @@ Lightning Flash
    reference/task
    reference/image_classification
+   reference/summarization
    reference/text_classification
    reference/tabular_classification
+   reference/translation

 .. toctree::
    :maxdepth: 1
diff --git a/docs/source/reference/summarization.rst b/docs/source/reference/summarization.rst
new file mode 100644
index 00000000000..eaf4cdb1b73
--- /dev/null
+++ b/docs/source/reference/summarization.rst
@@ -0,0 +1,178 @@
.. _summarization:

#############
Summarization
#############

********
The task
********

Summarization is the task of condensing a larger document or article into a short sentence or description. For example, taking a web article and describing its topic in a single sentence.
This task belongs to the family of Sequence to Sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the article is the input sequence, and the short description is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.SummarizationTask` is already pre-trained on `XSUM <https://arxiv.org/abs/1808.08745>`_, a dataset of online British Broadcasting Corporation articles.

Use the :class:`~flash.text.SummarizationTask` pretrained model for inference on any string sequence using :func:`~flash.text.SummarizationTask.predict`:

.. code-block:: python

    # import our libraries
    from flash.text import SummarizationTask

    # 1. Load the model from a checkpoint
    model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

    # 2. Perform inference from a sequence
    predictions = model.predict([
        """
        Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
        people within the community. The couple were surrounded by shoppers as they walked along Electric Avenue.
        They came to Brixton to see work which has started to revitalise the borough.
        It was Charles' first visit to the area since 1996, when he was accompanied by the former
        South African president Nelson Mandela. Greengrocer Derek Chong, who has run a stall on Electric Avenue
        for 20 years, said Camilla had been "nice and pleasant" when she purchased the fruit.
        "She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
        She asked me were they ripe and I said yes - they're from the Dominican Republic."
        Mr Chong is one of 170 local retailers who accept the Brixton Pound.
        Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
        or in participating shops.
        During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
        nearby on an estate off Coldharbour Lane. Mr West said:
        "He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."
        He added: "I told him I was working with young kids and he said, 'Keep up all the good work.'"
        Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
        The trust hopes to restore and refurbish the building,
        where once Jimi Hendrix and The Clash played, as a new community and business centre.
        """
    ])
    print(predictions)

Or on a given dataset:

.. code-block:: python

    # import our libraries
    from flash import download_data
    from flash.text import SummarizationTask

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

    # 2. Load the model from a checkpoint
    model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

    # 3. Perform inference from a csv file
    predictions = model.predict("data/xsum/predict.csv")
    print(predictions)
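The CSV passed to :func:`~flash.text.SummarizationTask.predict` should contain the text to summarize. If you need to assemble such a file from raw articles, here is a minimal sketch with pandas; the ``input`` column name is an assumption, mirroring the training format shown in the Finetuning section below:

.. code-block:: python

    # a minimal sketch (assumption: predict.csv uses the same "input" column as the training CSVs)
    import pandas as pd

    articles = [
        "First article to summarize ...",
        "Second article to summarize ...",
    ]
    pd.DataFrame({"input": articles}).to_csv("data/xsum/predict.csv", index=False)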
For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own summarization data. We use the XSUM dataset as an example, which contains a ``train.csv`` and a ``valid.csv``, structured like so:

.. code-block::

    input,target
    "The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
    "Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...","Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005."
    ...

In the above, the ``input`` column holds the long articles/documents, and the ``target`` column holds the short descriptions used as training targets.

All we need is a few lines of code to train our model!

.. code-block:: python

    # import our libraries
    import flash
    from flash import download_data
    from flash.text import SummarizationData, SummarizationTask

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

    # 2. Organize the data
    datamodule = SummarizationData.from_files(
        train_file="data/xsum/train.csv",
        valid_file="data/xsum/valid.csv",
        test_file="data/xsum/test.csv",
        input="input",
        target="target"
    )

    # 3. Build the task
    model = SummarizationTask()

    # 4. Create the trainer
    trainer = flash.Trainer(max_epochs=1, gpus=1)

    # 5. Finetune the task
    trainer.finetune(model, datamodule=datamodule)

    # 6. Save the task as a checkpoint
    trainer.save_checkpoint("summarization_model_xsum.pt")
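Once finetuning completes, the saved checkpoint can be loaded back for inference just like the pre-trained weights. A minimal sketch, reusing the local checkpoint path from the snippet above:

.. code-block:: python

    from flash.text import SummarizationTask

    # load the checkpoint we just saved and summarize a new article
    model = SummarizationTask.load_from_checkpoint("summarization_model_xsum.pt")
    print(model.predict(["A new article to summarize ..."]))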
----

To run the example:

.. code-block:: bash

    python flash_examples/finetuning/summarization.py


------

*********************
Changing the backbone
*********************
By default, we use the `T5 <https://arxiv.org/abs/1910.10683>`_ model for summarization. You can change which model to use by passing in the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

    datamodule = SummarizationData.from_files(
        train_file="data/xsum/train.csv",
        valid_file="data/xsum/valid.csv",
        test_file="data/xsum/test.csv",
        input="input",
        target="target",
        backbone="google/mt5-small",
    )

    model = SummarizationTask(backbone="google/mt5-small")

------

*************
API reference
*************

.. _summarization_task:

SummarizationTask
-----------------

.. autoclass:: flash.text.summarization.model.SummarizationTask
    :members:
    :exclude-members: forward

.. _summarization_data:

SummarizationData
-----------------

.. autoclass:: flash.text.summarization.data.SummarizationData

.. automethod:: flash.text.summarization.data.SummarizationData.from_files
diff --git a/docs/source/reference/text_classification.rst b/docs/source/reference/text_classification.rst
index 0552d0bb0d7..66e7be505b0 100644
--- a/docs/source/reference/text_classification.rst
+++ b/docs/source/reference/text_classification.rst
@@ -16,9 +16,9 @@ Text classification is the task of assigning a piece of text (word, sentence or
 Inference
 *********
 
-The :class:`~flash.text.TextClassificatier` is already pre-trained on [IMDB](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews), a dataset of highly polarized movie reviews, trained for binary classification- to predict if a given review has a positive or negative sentiment.
+The :class:`~flash.text.TextClassifier` is already pre-trained on `IMDB <https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews>`_, a dataset of highly polarized movie reviews, trained for binary classification: to predict whether a given review has a positive or negative sentiment.
 
-Use the :class:`~flash.text.TextClassificatier` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`:
+Use the :class:`~flash.text.TextClassifier` pretrained model for inference on any string sequence using :func:`~flash.text.TextClassifier.predict`:
 
 .. code-block:: python
 
@@ -166,6 +166,3 @@ TextClassificationData
 
 .. autoclass:: flash.text.classification.data.TextClassificationData
 
 .. automethod:: flash.text.classification.data.TextClassificationData.from_files
-
-
-
diff --git a/docs/source/reference/translation.rst b/docs/source/reference/translation.rst
new file mode 100644
index 00000000000..f55cb70e6d7
--- /dev/null
+++ b/docs/source/reference/translation.rst
@@ -0,0 +1,160 @@
.. _translation:

###########
Translation
###########

********
The task
********

Translation is the task of translating text from a source language to a target language, such as English to Romanian.
This task belongs to the family of Sequence to Sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the English text is the input sequence, and the Romanian sentence is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.TranslationTask` is already pre-trained on `WMT16 English/Romanian <https://www.statmt.org/wmt16/translation-task.html>`_, a dataset of English to Romanian samples, based on the Europarl corpora.

Use the :class:`~flash.text.TranslationTask` pretrained model for inference on any string sequence using :func:`~flash.text.TranslationTask.predict`:

.. code-block:: python

    # import our libraries
    from flash.text import TranslationTask

    # 1. Load the model from a checkpoint
    model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

    # 2. Perform inference from a list of sequences
    predictions = model.predict([
        "BBC News went to meet one of the project's first graduates.",
        "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
    ])
    print(predictions)

Or on a given dataset:

.. code-block:: python

    # import our libraries
    from flash import download_data
    from flash.text import TranslationTask

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

    # 2. Load the model from a checkpoint
    model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

    # 3. Perform inference from a csv file
    predictions = model.predict("data/wmt_en_ro/predict.csv")
    print(predictions)

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own translation data. We use the English/Romanian WMT16 dataset as an example, which contains a ``train.csv`` and a ``valid.csv``, structured like so:

.. code-block::

    input,target
    "Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
    "Closure of sitting","Ridicarea şedinţei"
    ...

In the above, the ``input`` and ``target`` columns hold the English text and its Romanian translation, respectively.

All we need is a few lines of code to train our model!

.. code-block:: python

    # import our libraries
    import flash
    from flash import download_data
    from flash.text import TranslationData, TranslationTask

    # 1. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

    # 2. Organize the data
    datamodule = TranslationData.from_files(
        train_file="data/wmt_en_ro/train.csv",
        valid_file="data/wmt_en_ro/valid.csv",
        test_file="data/wmt_en_ro/test.csv",
        input="input",
        target="target",
    )

    # 3. Build the task
    model = TranslationTask()

    # 4. Create the trainer
    trainer = flash.Trainer(max_epochs=5, gpus=1, precision=16)

    # 5. Finetune the task
    trainer.finetune(model, datamodule=datamodule)

    # 6. Save the task as a checkpoint
    trainer.save_checkpoint("translation_model_en_ro.pt")
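To get a quick quality estimate on held-out data after finetuning, you can run an evaluation pass. A minimal sketch, assuming ``flash.Trainer`` exposes Lightning's standard ``test`` method and that the ``test_file`` configured above provides the test split:

.. code-block:: python

    # evaluate the finetuned task on the test split defined in the datamodule
    trainer.test(model, datamodule=datamodule)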
----

To run the example:

.. code-block:: bash

    python flash_examples/finetuning/translation.py


------

*********************
Changing the backbone
*********************
By default, we use the `MarianNMT <https://marian-nmt.github.io/>`_ model for translation. You can change which model to use by passing in the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

    datamodule = TranslationData.from_files(
        train_file="data/wmt_en_ro/train.csv",
        valid_file="data/wmt_en_ro/valid.csv",
        test_file="data/wmt_en_ro/test.csv",
        input="input",
        target="target",
        backbone="t5-small",
    )

    model = TranslationTask(backbone="t5-small")

------

*************
API reference
*************

.. _translation_task:

TranslationTask
---------------

.. autoclass:: flash.text.translation.model.TranslationTask
    :members:
    :exclude-members: forward

.. _translation_data:

TranslationData
---------------

.. autoclass:: flash.text.translation.data.TranslationData

.. automethod:: flash.text.translation.data.TranslationData.from_files