This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Fix translation (#332)
* Fix translation

* Fix example

* Fix

* Add metric test

* Fixes

* Fix pickle bug for classification

* Update CHANGELOG.md

* Add tests
ethanwharris authored May 28, 2021
1 parent ae6801f commit 1054949
Showing 17 changed files with 289 additions and 51 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -11,6 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))


- Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


- Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


## [0.3.0] - 2021-05-20

### Added
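The tokenizer-pickling fix below follows one pattern throughout the data sources: drop the Hugging Face tokenizer from the pickled state and rebuild it from the stored backbone name on unpickling. A minimal, self-contained sketch of that pattern (the `DummySource` class and the `bert-base-uncased` checkpoint are illustrative, not part of this commit):

```python
import pickle

from transformers import AutoTokenizer


class DummySource:
    """Illustrative stand-in for the data sources patched in this commit."""

    def __init__(self, backbone: str = "bert-base-uncased"):
        self.backbone = backbone
        self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)

    def __getstate__(self):
        # Mirror the commit's workaround: keep the tokenizer out of the pickled state.
        state = self.__dict__.copy()
        state.pop("tokenizer")
        return state

    def __setstate__(self, state):
        # Rebuild the tokenizer from the stored backbone name after unpickling.
        self.__dict__.update(state)
        self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


source = pickle.loads(pickle.dumps(DummySource()))
assert source.tokenizer is not None  # round-trips without ever pickling the tokenizer
```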
46 changes: 46 additions & 0 deletions flash/text/classification/data.py
@@ -38,6 +38,7 @@ def __init__(self, backbone: str, max_length: int = 128):
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length

@@ -55,6 +56,15 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str,
ex[target] = label_to_class_mapping[ex[target]]
return ex

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextFileDataSource(TextDataSource):

@@ -115,18 +125,45 @@ def load_data(
def predict_load_data(self, data: Any, dataset: AutoDataset):
return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextCSVDataSource(TextFileDataSource):

def __init__(self, backbone: str, max_length: int = 128):
super().__init__("csv", backbone, max_length=max_length)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextJSONDataSource(TextFileDataSource):

def __init__(self, backbone: str, max_length: int = 128):
super().__init__("json", backbone, max_length=max_length)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):

@@ -143,6 +180,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s, ) for s in data]

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextClassificationPreprocess(Preprocess):

90 changes: 88 additions & 2 deletions flash/text/seq2seq/core/data.py
@@ -21,7 +21,7 @@
import flash
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _TEXT_AVAILABLE

@@ -45,7 +45,8 @@ def __init__(
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.padding = padding
@@ -71,6 +72,15 @@ def _tokenize_fn(
padding=self.padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqFileDataSource(Seq2SeqDataSource):

@@ -112,6 +122,15 @@ def load_data(
def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]:
return self.load_data(data, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqCSVDataSource(Seq2SeqFileDataSource):

@@ -130,6 +149,15 @@ def __init__(
padding=padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqJSONDataSource(Seq2SeqFileDataSource):

@@ -148,6 +176,15 @@ def __init__(
padding=padding,
)

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqSentencesDataSource(Seq2SeqDataSource):

@@ -161,6 +198,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s) for s in data]

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


@dataclass(unsafe_hash=True, frozen=True)
class Seq2SeqBackboneState(ProcessState):
@@ -240,7 +286,47 @@ def collate(self, samples: Any) -> Tensor:
return default_data_collator(samples)


class Seq2SeqPostprocess(Postprocess):

def __init__(self):
super().__init__()

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self._backbone = None
self._tokenizer = None

@property
def backbone(self):
backbone_state = self.get_state(Seq2SeqBackboneState)
if backbone_state is not None:
return backbone_state.backbone

@property
def tokenizer(self):
if self.backbone is not None and self.backbone != self._backbone:
self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self._backbone = self.backbone
return self._tokenizer

def uncollate(self, generated_tokens: Any) -> Any:
pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
pred_str = [str.strip(s) for s in pred_str]
return pred_str

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("_tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class Seq2SeqData(DataModule):
"""Data module for Seq2Seq tasks."""

preprocess_cls = Seq2SeqPreprocess
postprocess_cls = Seq2SeqPostprocess
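For context, `Seq2SeqPostprocess.uncollate` above is the piece that addresses the decoding fix noted in the changelog: generated token ids are turned back into plain strings in a single `batch_decode` call with special tokens stripped. A rough standalone illustration of that step (`t5-small` is used here only because its tokenizer is small and quick to download; it is not tied to this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small", use_fast=True)

# Pretend these ids came out of a seq2seq model's generate() call.
generated_tokens = tokenizer(
    ["translate English to Romanian: the house is small", "hello world"],
    padding=True,
)["input_ids"]

pred_str = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
pred_str = [s.strip() for s in pred_str]  # same strip as uncollate
print(pred_str)  # padding and end-of-sequence markers are gone from the output
```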
41 changes: 2 additions & 39 deletions flash/text/seq2seq/summarization/data.py
@@ -11,47 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from flash.core.data.process import Postprocess
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess

if _TEXT_AVAILABLE:
from transformers import AutoTokenizer


class SummarizationPostprocess(Postprocess):

def __init__(self):
super().__init__()

if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self._backbone = None
self._tokenizer = None

@property
def backbone(self):
backbone_state = self.get_state(Seq2SeqBackboneState)
if backbone_state is not None:
return backbone_state.backbone

@property
def tokenizer(self):
if self.backbone is not None and self.backbone != self._backbone:
self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self._backbone = self.backbone
return self._tokenizer

def uncollate(self, generated_tokens: Any) -> Any:
pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
pred_str = [str.strip(s) for s in pred_str]
return pred_str
from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess


class SummarizationData(Seq2SeqData):

preprocess_cls = Seq2SeqPreprocess
postprocess_cls = SummarizationPostprocess
postprocess_cls = Seq2SeqPostprocess
3 changes: 2 additions & 1 deletion flash/text/seq2seq/translation/data.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Callable, Dict, Optional, Union

from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess
from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess


class TranslationPreprocess(Seq2SeqPreprocess):
@@ -45,3 +45,4 @@ class TranslationData(Seq2SeqData):
"""Data module for Translation tasks."""

preprocess_cls = TranslationPreprocess
postprocess_cls = Seq2SeqPostprocess
4 changes: 1 addition & 3 deletions flash/text/seq2seq/translation/metric.py
@@ -81,9 +81,7 @@ def compute(self):
return tensor(0.0, device=self.r.device)

if self.smooth:
precision_scores = torch.add(self.numerator, torch.ones(
self.n_gram
)) / torch.add(self.denominator, torch.ones(self.n_gram))
precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
else:
precision_scores = self.numerator / self.denominator

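The smoothed branch above is simplified to plain add-one smoothing of the n-gram precisions, dropping the intermediate `torch.ones` tensors; add-one smoothing is what keeps BLEU non-zero when a higher-order n-gram has no matches. A small numeric sketch with made-up counts (not taken from the repository's tests):

```python
import torch

# Hypothetical match counts (numerator) and totals (denominator) for n = 1..4.
numerator = torch.tensor([9.0, 5.0, 2.0, 0.0])
denominator = torch.tensor([10.0, 9.0, 8.0, 7.0])

raw = numerator / denominator                       # 4-gram precision is 0, so BLEU collapses to 0
smoothed = (numerator + 1.0) / (denominator + 1.0)  # every precision stays positive

print(raw)       # tensor([0.9000, 0.5556, 0.2500, 0.0000])
print(smoothed)  # tensor([0.9091, 0.6000, 0.3333, 0.1250])
```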
10 changes: 5 additions & 5 deletions flash/text/seq2seq/translation/model.py
@@ -28,7 +28,7 @@ class TranslationTask(Seq2SeqTask):
loss_fn: Loss function for training.
optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training, defaults to `3e-4`
learning_rate: Learning rate to use for training, defaults to `1e-5`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
@@ -41,11 +41,11 @@
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 3e-4,
learning_rate: float = 1e-5,
val_target_max_length: Optional[int] = 128,
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = False,
smooth: bool = True,
):
self.save_hyperparameters()
super().__init__(
@@ -70,11 +70,11 @@ def compute_metrics(self, generated_tokens, batch, prefix):
tgt_lns = self.tokenize_labels(batch["labels"])
# wrap targets in list as score expects a list of potential references
tgt_lns = [[reference] for reference in tgt_lns]
result = self.bleu(generated_tokens, tgt_lns)
result = self.bleu(self._postprocess.uncollate(generated_tokens), tgt_lns)
self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
"""
This function is used only for debugging usage with CI
"""
# assert history[-1]["val_bleu_score"]
assert history[-1]["val_bleu_score"] > 0.6
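As a reading aid for the `compute_metrics` change above: BLEU is now fed decoded prediction strings (via the shared postprocess) rather than raw token ids, with each target wrapped in a one-element list because the score accepts several references per prediction. A hedged sketch of that data shape, where `decode` is a stand-in for `self._postprocess.uncollate` / `self.tokenize_labels` rather than the repository's exact attributes:

```python
from typing import Any, Callable, List


def prepare_bleu_inputs(
    generated_tokens: Any,
    labels: Any,
    decode: Callable[[Any], List[str]],
):
    pred_lns = decode(generated_tokens)          # token ids -> stripped strings
    tgt_lns = [[ref] for ref in decode(labels)]  # one reference list per prediction
    return pred_lns, tgt_lns
```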
2 changes: 1 addition & 1 deletion flash_examples/finetuning/translation.py
@@ -20,7 +20,7 @@
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

backbone = "t5-small"
backbone = "Helsinki-NLP/opus-mt-en-ro"

# 2. Load the data
datamodule = TranslationData.from_csv(
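The finetuning example now points at a dedicated English-to-Romanian checkpoint instead of `t5-small`. Not part of the commit, but a quick way to sanity-check that this backbone translates out of the box with plain `transformers` (assuming the `sentencepiece` package is also installed):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

backbone = "Helsinki-NLP/opus-mt-en-ro"
tokenizer = AutoTokenizer.from_pretrained(backbone)
model = AutoModelForSeq2SeqLM.from_pretrained(backbone)

batch = tokenizer(["The house is wonderful."], return_tensors="pt", padding=True)
generated = model.generate(**batch)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```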