
Commit

Fix summarization (#372)
* Fix summarization

* Remove debug

* Add test

* fixes
ethanwharris authored Jun 7, 2021
1 parent 7f8c3e8 commit bded76d
Showing 5 changed files with 69 additions and 8 deletions.
29 changes: 28 additions & 1 deletion flash/text/seq2seq/summarization/data.py
@@ -11,10 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Callable, Dict, Optional, Union
+
 from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess
 
 
+class SummarizationPreprocess(Seq2SeqPreprocess):
+
+    def __init__(
+        self,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        backbone: str = "sshleifer/distilbart-xsum-1-1",
+        max_source_length: int = 128,
+        max_target_length: int = 128,
+        padding: Union[str, bool] = 'max_length'
+    ):
+        super().__init__(
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            backbone=backbone,
+            max_source_length=max_source_length,
+            max_target_length=max_target_length,
+            padding=padding,
+        )
+
+
 class SummarizationData(Seq2SeqData):
 
-    preprocess_cls = Seq2SeqPreprocess
+    preprocess_cls = SummarizationPreprocess
     postprocess_cls = Seq2SeqPostprocess
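
Note: the new SummarizationPreprocess only supplies summarization-specific defaults (a distilbart-xsum backbone and 128-token source/target lengths) and forwards everything to Seq2SeqPreprocess. As a rough sketch of what the change means for callers, the snippet below instantiates the class with the defaults visible in the diff; treat it as illustration only, since in normal use SummarizationData constructs the preprocess itself via preprocess_cls.

from flash.text.seq2seq.summarization.data import SummarizationData, SummarizationPreprocess

# All keyword arguments below come straight from the __init__ signature added above.
preprocess = SummarizationPreprocess(
    backbone="sshleifer/distilbart-xsum-1-1",
    max_source_length=128,
    max_target_length=128,
    padding="max_length",
)

# After this commit SummarizationData points at the new class, so data modules
# built from it pick up these defaults automatically.
assert SummarizationData.preprocess_cls is SummarizationPreprocess
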
2 changes: 1 addition & 1 deletion flash/text/seq2seq/summarization/metric.py
@@ -21,7 +21,7 @@
 from flash.text.seq2seq.summarization.utils import add_newline_to_end_of_each_sentence
 
 if _TEXT_AVAILABLE:
-    from rouge_score import rouge_scorer, scoring
+    from rouge_score import rouge_scorer
     from rouge_score.scoring import AggregateScore, BootstrapAggregator, Score
 else:
     AggregateScore, Score, BootstrapAggregator = None, None, object
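
Note: the dropped `scoring` module import was redundant; the names the metric actually uses (AggregateScore, BootstrapAggregator, Score) are already imported from rouge_score.scoring on the following line. For orientation only, a minimal sketch of the underlying rouge_score API those imports refer to (not code from the flash source):

from rouge_score import rouge_scorer

# Score a single prediction against a single reference; each entry in the result
# is a Score tuple with precision, recall and fmeasure fields.
scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
scores = scorer.score("the cat sat on the mat", "a cat sat on the mat")
print(scores["rouge1"].recall)
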
8 changes: 4 additions & 4 deletions flash/text/seq2seq/summarization/model.py
@@ -37,11 +37,11 @@ class SummarizationTask(Seq2SeqTask):
 
     def __init__(
         self,
-        backbone: str = "sshleifer/tiny-mbart",
+        backbone: str = "sshleifer/distilbart-xsum-1-1",
         loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
         optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
         metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
-        learning_rate: float = 5e-5,
+        learning_rate: float = 1e-5,
         val_target_max_length: Optional[int] = None,
         num_beams: Optional[int] = 4,
         use_stemmer: bool = True,
@@ -69,10 +69,10 @@ def task(self) -> str:
     def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None:
         tgt_lns = self.tokenize_labels(batch["labels"])
         result = self.rouge(self._postprocess.uncollate(generated_tokens), tgt_lns)
-        self.log_dict(result, on_step=False, on_epoch=True)
+        self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True)
 
     def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
         """
         This function is used only for debugging usage with CI
         """
-        assert history[-1]["val_f1"] > 0.45
+        assert history[-1]["rouge1_recall"] > 0.2
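
Note: the task defaults now match the data side of the change: a real summarization checkpoint (sshleifer/distilbart-xsum-1-1) instead of the tiny mBART test model, plus a lower learning rate, and the ROUGE results are logged to the progress bar while the CI benchmark checks rouge1_recall. A minimal construction sketch using the new defaults, assuming the flash.text import path used by the repo's examples:

from flash.text import SummarizationTask

# Both values below are now the defaults after this commit, so they could equally be omitted.
task = SummarizationTask(
    backbone="sshleifer/distilbart-xsum-1-1",
    learning_rate=1e-5,
)
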
12 changes: 10 additions & 2 deletions tests/examples/test_scripts.py
@@ -93,7 +93,11 @@ def run_test(filepath):
             "semantic_segmentation.py",
             marks=pytest.mark.skipif(not _IMAGE_AVAILABLE, reason="image libraries aren't installed")
         ),
-        # pytest.param("finetuning", "summarization.py"), # TODO: takes too long.
+        pytest.param(
+            "finetuning",
+            "summarization.py",
+            marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed")
+        ),
         pytest.param(
             "finetuning",
             "tabular_classification.py",
@@ -147,7 +151,11 @@ def run_test(filepath):
             "video_classification.py",
             marks=pytest.mark.skipif(not _VIDEO_AVAILABLE, reason="video libraries aren't installed")
         ),
-        # pytest.param("predict", "summarization.py"), # TODO: takes too long
+        pytest.param(
+            "predict",
+            "summarization.py",
+            marks=pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed")
+        ),
         pytest.param(
             "predict",
             "template.py",
26 changes: 26 additions & 0 deletions tests/text/seq2seq/summarization/test_metric.py
@@ -0,0 +1,26 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+import torch
+
+from flash.core.utilities.imports import _TEXT_AVAILABLE
+from flash.text.seq2seq.summarization.metric import RougeMetric
+
+
+@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.")
+def test_rouge():
+    preds = "My name is John".split()
+    target = "Is your name John".split()
+    metric = RougeMetric()
+    assert torch.allclose(torch.tensor(metric(preds, target)["rouge1_recall"]).float(), torch.tensor(0.25), 1e-4)

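Note on the expected value: preds and target are whitespace-split word lists, so the asserted 0.25 is consistent with RougeMetric scoring corresponding entries pairwise and averaging, since only the ("John", "John") pair overlaps: (0 + 0 + 0 + 1) / 4 = 0.25. The sketch below reproduces that arithmetic with rouge_score directly, under that pairwise assumption (a hypothetical reading of the metric, not taken from the flash source):

from rouge_score import rouge_scorer
from rouge_score.scoring import BootstrapAggregator

preds = "My name is John".split()
target = "Is your name John".split()

# Pairing entries element-wise is an assumption about how RougeMetric consumes lists.
scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
aggregator = BootstrapAggregator()
for tgt, pred in zip(target, preds):
    aggregator.add_scores(scorer.score(tgt, pred))
print(aggregator.aggregate()["rouge1"].mid.recall)  # approximately 0.25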