Commit 495b812
SeanNaren committed on Feb 1, 2021 (1 parent: f37a41a)
Showing 4 changed files with 342 additions and 5 deletions.
.. _summarization:

#############
Summarization
#############

********
The task
********

Summarization is the task of condensing a longer document or article into a short sentence or description, for example taking a web article and describing its topic in a single sentence.
This task is a subset of sequence-to-sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the article is the input sequence, and the short description is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.SummarizationTask` is already pre-trained on `XSUM <https://arxiv.org/abs/1808.08745>`_, a dataset of online British Broadcasting Corporation articles.

Use the pretrained :class:`~flash.text.SummarizationTask` model for inference on any string sequence using :func:`~flash.text.SummarizationTask.predict`:

.. code-block:: python

    # 1. Import our libraries
    from flash.text import SummarizationTask

    # 2. Load the model from a checkpoint
    model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

    # 3. Perform inference from a sequence
    predictions = model.predict([
        """
        Camilla bought a box of mangoes with a Brixton £10 note, introduced last year to try to keep the money of local
        people within the community. The couple were surrounded by shoppers as they walked along Electric Avenue.
        They came to Brixton to see work which has started to revitalise the borough.
        It was Charles' first visit to the area since 1996, when he was accompanied by the former
        South African president Nelson Mandela. Greengrocer Derek Chong, who has run a stall on Electric Avenue
        for 20 years, said Camilla had been "nice and pleasant" when she purchased the fruit.
        "She asked me what was nice, what would I recommend, and I said we've got some nice mangoes.
        She asked me were they ripe and I said yes - they're from the Dominican Republic."
        Mr Chong is one of 170 local retailers who accept the Brixton Pound.
        Customers exchange traditional pound coins for Brixton Pounds and then spend them at the market
        or in participating shops.
        During the visit, Prince Charles spent time talking to youth worker Marcus West, who works with children
        nearby on an estate off Coldharbour Lane. Mr West said:
        "He's on the level, really down-to-earth. They were very cheery. The prince is a lovely man."
        He added: "I told him I was working with young kids and he said, 'Keep up all the good work.'"
        Prince Charles also visited the Railway Hotel, at the invitation of his charity The Prince's Regeneration Trust.
        The trust hopes to restore and refurbish the building,
        where once Jimi Hendrix and The Clash played, as a new community and business centre.
        """
    ])
    print(predictions)

Or on a given dataset:

.. code-block:: python

    # 1. Import our libraries
    from flash import download_data
    from flash.text import SummarizationTask

    # 2. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

    # 3. Load the model from a checkpoint
    model = SummarizationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/summarization_model_xsum.pt")

    # 4. Perform inference from a csv file
    predictions = model.predict("data/xsum/predict.csv")
    print(predictions)

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own summarization data. We use the XSUM dataset as an example, which contains a ``train.csv`` and ``valid.csv`` structured like so:

.. code-block::

    input,target
    "The researchers have sequenced the genome of a strain of bacterium that causes the virulent infection...","A team of UK scientists hopes to shed light on the mysteries of bleeding canker, a disease that is threatening the nation's horse chestnut trees."
    "Knight was shot in the leg by an unknown gunman at Miami's Shore Club where West was holding a pre-MTV Awards...",Hip hop star Kanye West is being sued by Death Row Records founder Suge Knight over a shooting at a beach party in August 2005.
    ...

In the above, the ``input`` column holds the long articles/documents, and ``target`` is the short description the model learns to generate.
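
A quick way to sanity-check this column layout is to peek at the file with pandas. This is a minimal sketch, assuming pandas is installed and the data has already been downloaded to ``data/xsum/``:

.. code-block:: python

    import pandas as pd

    # Load the training split and confirm the input/target columns are present
    df = pd.read_csv("data/xsum/train.csv")
    print(df.columns.tolist())  # expect ['input', 'target']
    print(df.head(2))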

All you need is a few lines of code to train the model!

.. code-block:: python

    # 1. Import our libraries
    import flash
    from flash import download_data
    from flash.text import SummarizationData, SummarizationTask

    # 2. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/xsum.zip", "data/")

    # 3. Organize the data
    datamodule = SummarizationData.from_files(
        train_file="data/xsum/train.csv",
        valid_file="data/xsum/valid.csv",
        test_file="data/xsum/test.csv",
        input="input",
        target="target",
    )

    # 4. Build the task
    model = SummarizationTask()

    # 5. Create the trainer
    trainer = flash.Trainer(max_epochs=1, gpus=1)

    # 6. Finetune the task
    trainer.finetune(model, datamodule=datamodule)

    # 7. Save the checkpoint
    trainer.save_checkpoint("summarization_model_xsum.pt")
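
Once saved, the checkpoint can be loaded back for inference just like the hosted weights used earlier. A minimal sketch, assuming the checkpoint above was written to the current working directory:

.. code-block:: python

    from flash.text import SummarizationTask

    # Load the finetuned weights from the local checkpoint
    model = SummarizationTask.load_from_checkpoint("summarization_model_xsum.pt")
    predictions = model.predict(["A long news article to summarize ..."])
    print(predictions)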

----

To run the example:

.. code-block:: bash

    python flash_examples/finetuning/summarization.py

------

*********************
Changing the backbone
*********************

By default, we use the `T5 <https://arxiv.org/abs/1910.10683>`_ model for summarization. You can change the model used by passing in the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

    datamodule = SummarizationData.from_files(
        train_file="data/xsum/train.csv",
        valid_file="data/xsum/valid.csv",
        test_file="data/xsum/test.csv",
        input="input",
        target="target",
        backbone="google/mt5-small",
    )

    model = SummarizationTask(backbone="google/mt5-small")

------

*************
API reference
*************

.. _summarization_task:

SummarizationTask
-----------------

.. autoclass:: flash.text.summarization.model.SummarizationTask
    :members:
    :exclude-members: forward

.. _summarization_data:

SummarizationData
-----------------

.. autoclass:: flash.text.summarization.data.SummarizationData

.. automethod:: flash.text.summarization.data.SummarizationData.from_files

.. _translation:

###########
Translation
###########

********
The task
********

Translation is the task of translating text from a source language to a target language, such as English to Romanian.
This task is a subset of sequence-to-sequence tasks, which require the model to generate a variable-length sequence given an input sequence. In our case the English text is the input sequence, and the Romanian sentence is the output sequence generated by the model.

-----

*********
Inference
*********

The :class:`~flash.text.TranslationTask` is already pre-trained on `WMT16 English/Romanian <https://www.statmt.org/wmt16/translation-task.html>`_, a dataset of English to Romanian samples based on the Europarl corpora.

Use the pretrained :class:`~flash.text.TranslationTask` model for inference on any string sequence using :func:`~flash.text.TranslationTask.predict`:

.. code-block:: python

    # 1. Import our libraries
    from flash.text import TranslationTask

    # 2. Load the model from a checkpoint
    model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

    # 3. Perform inference from a list of sequences
    predictions = model.predict([
        "BBC News went to meet one of the project's first graduates.",
        "A recession has come as quickly as 11 months after the first rate hike and as long as 86 months.",
    ])
    print(predictions)

Or on a given dataset:

.. code-block:: python

    # 1. Import our libraries
    from flash import download_data
    from flash.text import TranslationTask

    # 2. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

    # 3. Load the model from a checkpoint
    model = TranslationTask.load_from_checkpoint("https://flash-weights.s3.amazonaws.com/translation_model_en_ro.pt")

    # 4. Perform inference from a csv file
    predictions = model.predict("data/wmt_en_ro/predict.csv")
    print(predictions)

For more advanced inference options, see :ref:`predictions`.

-----

**********
Finetuning
**********

Say you want to finetune on your own translation data. We use the English/Romanian WMT16 dataset as an example, which contains a ``train.csv`` and ``valid.csv`` structured like so:

.. code-block::

    input,target
    "Written statements and oral questions (tabling): see Minutes","Declaraţii scrise şi întrebări orale (depunere): consultaţi procesul-verbal"
    "Closure of sitting","Ridicarea şedinţei"
    ...

In the above, the ``input`` and ``target`` columns hold the English text and its Romanian translation, respectively.
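
To double-check that a file follows this layout before training, you can print the first few rows with Python's built-in ``csv`` module. A minimal sketch, assuming the data has already been downloaded to ``data/wmt_en_ro/``:

.. code-block:: python

    import csv

    # Print the header row and the first two input/target pairs
    with open("data/wmt_en_ro/train.csv", newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        for i, row in enumerate(reader):
            print(row)
            if i >= 2:
                break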

All you need is a few lines of code to train the model!

.. code-block:: python

    # 1. Import our libraries
    import flash
    from flash import download_data
    from flash.text import TranslationData, TranslationTask

    # 2. Download the data
    download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

    # 3. Organize the data
    datamodule = TranslationData.from_files(
        train_file="data/wmt_en_ro/train.csv",
        valid_file="data/wmt_en_ro/valid.csv",
        test_file="data/wmt_en_ro/test.csv",
        input="input",
        target="target",
    )

    # 4. Build the task
    model = TranslationTask()

    # 5. Create the trainer
    trainer = flash.Trainer(max_epochs=5, gpus=1, precision=16)

    # 6. Finetune the task
    trainer.finetune(model, datamodule=datamodule)

    # 7. Save the checkpoint
    trainer.save_checkpoint("translation_model_en_ro.pt")
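
Because the datamodule above was given a ``test_file``, you can also score the finetuned model on the held-out split. This is a sketch that assumes the standard Lightning ``Trainer.test`` API, which ``flash.Trainer`` inherits:

.. code-block:: python

    # Evaluate the finetuned model on the test split
    trainer.test(model, datamodule=datamodule)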

----

To run the example:

.. code-block:: bash

    python flash_examples/finetuning/translation.py

------

*********************
Changing the backbone
*********************

By default, we use the `MarianNMT <https://marian-nmt.github.io/>`_ model for translation. You can change the model used by passing in the ``backbone`` parameter.

.. note:: When changing the backbone, make sure you pass the same backbone to the Task and the Data object! Since this is a Seq2Seq task, make sure you use a Seq2Seq model.

.. code-block:: python

    datamodule = TranslationData.from_files(
        train_file="data/wmt_en_ro/train.csv",
        valid_file="data/wmt_en_ro/valid.csv",
        test_file="data/wmt_en_ro/test.csv",
        input="input",
        target="target",
        backbone="t5-small",
    )

    model = TranslationTask(backbone="t5-small")

------

*************
API reference
*************

.. _translation_task:

TranslationTask
---------------

.. autoclass:: flash.text.translation.model.TranslationTask
    :members:
    :exclude-members: forward

.. _translation_data:

TranslationData
---------------

.. autoclass:: flash.text.translation.data.TranslationData

.. automethod:: flash.text.translation.data.TranslationData.from_files