Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix translation #332

Merged
merged 9 commits into from
May 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343))


- Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


- Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))


## [0.3.0] - 2021-05-20

### Added
Expand Down
46 changes: 46 additions & 0 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self, backbone: str, max_length: int = 128):
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length

Expand All @@ -55,6 +56,15 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str,
ex[target] = label_to_class_mapping[ex[target]]
return ex

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class TextFileDataSource(TextDataSource):

Expand Down Expand Up @@ -115,18 +125,45 @@ def load_data(
def predict_load_data(self, data: Any, dataset: AutoDataset):
    """Load prediction data through ``load_data``, asking only for the input columns."""
    input_columns = ["input_ids", "attention_mask"]
    return self.load_data(data, dataset, columns=input_columns)

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class TextCSVDataSource(TextFileDataSource):
    """``TextFileDataSource`` initialised with ``"csv"`` as its first argument."""

    # ``__getstate__`` / ``__setstate__`` are inherited from
    # ``TextFileDataSource``; the copies that used to be re-declared here were
    # identical to the parent's and only had to be kept in sync by hand.

    def __init__(self, backbone: str, max_length: int = 128):
        super().__init__("csv", backbone, max_length=max_length)


class TextJSONDataSource(TextFileDataSource):
    """``TextFileDataSource`` initialised with ``"json"`` as its first argument."""

    # ``__getstate__`` / ``__setstate__`` are inherited from
    # ``TextFileDataSource``; the copies that used to be re-declared here were
    # identical to the parent's and only had to be kept in sync by hand.

    def __init__(self, backbone: str, max_length: int = 128):
        super().__init__("json", backbone, max_length=max_length)


class TextSentencesDataSource(TextDataSource):

Expand All @@ -143,6 +180,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s, ) for s in data]

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class TextClassificationPreprocess(Preprocess):

Expand Down
90 changes: 88 additions & 2 deletions flash/text/seq2seq/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import flash
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataSources
from flash.core.data.process import Preprocess
from flash.core.data.process import Postprocess, Preprocess
from flash.core.data.properties import ProcessState
from flash.core.utilities.imports import _TEXT_AVAILABLE

Expand All @@ -45,7 +45,8 @@ def __init__(
if not _TEXT_AVAILABLE:
raise ModuleNotFoundError("Please, pip install -e '.[text]'")

self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.padding = padding
Expand All @@ -71,6 +72,15 @@ def _tokenize_fn(
padding=self.padding,
)

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class Seq2SeqFileDataSource(Seq2SeqDataSource):

Expand Down Expand Up @@ -112,6 +122,15 @@ def load_data(
def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]:
    """Load prediction data through ``load_data``, asking only for the input columns."""
    input_columns = ["input_ids", "attention_mask"]
    return self.load_data(data, columns=input_columns)

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class Seq2SeqCSVDataSource(Seq2SeqFileDataSource):

Expand All @@ -130,6 +149,15 @@ def __init__(
padding=padding,
)

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class Seq2SeqJSONDataSource(Seq2SeqFileDataSource):

Expand All @@ -148,6 +176,15 @@ def __init__(
padding=padding,
)

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


class Seq2SeqSentencesDataSource(Seq2SeqDataSource):

Expand All @@ -161,6 +198,15 @@ def load_data(
data = [data]
return [self._tokenize_fn(s) for s in data]

def __getstate__(self):
    """Return the instance state for pickling, without the tokenizer.

    The tokenizer is deliberately left out of the pickled state;
    ``__setstate__`` loads it again from ``self.backbone``.

    Returns:
        A shallow copy of ``self.__dict__`` with the ``"tokenizer"`` entry removed.
    """
    # TODO: Find out why this is being pickled
    state = self.__dict__.copy()
    # The default keeps pickling from raising ``KeyError`` when the instance
    # has no ``tokenizer`` attribute.
    state.pop("tokenizer", None)
    return state

def __setstate__(self, state):
    """Restore pickled attributes, then load the tokenizer again.

    ``__getstate__`` drops the tokenizer, so it is re-created here from the
    restored ``backbone`` name.
    """
    vars(self).update(state)
    backbone = self.backbone
    self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)


@dataclass(unsafe_hash=True, frozen=True)
class Seq2SeqBackboneState(ProcessState):
Expand Down Expand Up @@ -240,7 +286,47 @@ def collate(self, samples: Any) -> Tensor:
return default_data_collator(samples)


class Seq2SeqPostprocess(Postprocess):
    """Decode generated token ids back into stripped strings.

    The tokenizer is not created in ``__init__``. It is loaded on first use
    from the backbone name exposed by ``Seq2SeqBackboneState`` and cached
    until that name changes. The tokenizer is never pickled; after
    unpickling it is loaded again on the next access.
    """

    def __init__(self):
        super().__init__()

        if not _TEXT_AVAILABLE:
            raise ModuleNotFoundError("Please, pip install -e '.[text]'")

        # Backbone name the cached tokenizer was loaded for, and the tokenizer itself.
        self._backbone = None
        self._tokenizer = None

    @property
    def backbone(self):
        """Backbone name from ``Seq2SeqBackboneState``, or ``None`` when no such state is set."""
        backbone_state = self.get_state(Seq2SeqBackboneState)
        if backbone_state is not None:
            return backbone_state.backbone
        return None

    @property
    def tokenizer(self):
        """Tokenizer for the current backbone, or ``None`` while no backbone is known.

        A tokenizer is (re)loaded when the backbone name differs from the one
        the cache was built for, or when the cache is empty (for example right
        after unpickling).
        """
        backbone = self.backbone
        if backbone is not None and (backbone != self._backbone or self._tokenizer is None):
            self._tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
            self._backbone = backbone
        return self._tokenizer

    def uncollate(self, generated_tokens: Any) -> Any:
        """Decode ``generated_tokens`` with the tokenizer and strip each resulting string."""
        pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return [s.strip() for s in pred_str]

    def __getstate__(self):
        """Return the instance state for pickling with the cached tokenizer cleared."""
        # TODO: Find out why this is being pickled
        state = self.__dict__.copy()
        # Clear rather than ``pop`` so the attribute still exists after
        # unpickling and a missing key can never raise ``KeyError`` here.
        state["_tokenizer"] = None
        return state

    def __setstate__(self, state):
        """Restore pickled attributes; the tokenizer is reloaded lazily by ``tokenizer``."""
        self.__dict__.update(state)
        # Loading eagerly here would hand ``None`` to ``from_pretrained``
        # whenever ``backbone`` has no state to read from yet. States pickled
        # without a ``_tokenizer`` entry are also accepted.
        self.__dict__.setdefault("_tokenizer", None)
        self.__dict__.setdefault("_backbone", None)


class Seq2SeqData(DataModule):
"""Data module for Seq2Seq tasks."""

preprocess_cls = Seq2SeqPreprocess
postprocess_cls = Seq2SeqPostprocess
41 changes: 2 additions & 39 deletions flash/text/seq2seq/summarization/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,47 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

from flash.core.data.process import Postprocess
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess

if _TEXT_AVAILABLE:
from transformers import AutoTokenizer


class SummarizationPostprocess(Postprocess):
    """Decode generated token ids into stripped strings with a lazily loaded tokenizer."""

    def __init__(self):
        super().__init__()

        if not _TEXT_AVAILABLE:
            raise ModuleNotFoundError("Please, pip install -e '.[text]'")

        # Nothing is loaded yet; ``tokenizer`` fills both fields on first use.
        self._tokenizer = None
        self._backbone = None

    @property
    def backbone(self):
        """Backbone name from ``Seq2SeqBackboneState``, or ``None`` when no such state is set."""
        state = self.get_state(Seq2SeqBackboneState)
        return None if state is None else state.backbone

    @property
    def tokenizer(self):
        """Tokenizer for the current backbone, reloaded whenever the backbone name changes."""
        stale = self.backbone is not None and self.backbone != self._backbone
        if stale:
            self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
            self._backbone = self.backbone
        return self._tokenizer

    def uncollate(self, generated_tokens: Any) -> Any:
        """Decode ``generated_tokens`` and strip surrounding whitespace from each string."""
        decoded = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        return [str.strip(text) for text in decoded]
from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess


class SummarizationData(Seq2SeqData):

preprocess_cls = Seq2SeqPreprocess
postprocess_cls = SummarizationPostprocess
postprocess_cls = Seq2SeqPostprocess
3 changes: 2 additions & 1 deletion flash/text/seq2seq/translation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from typing import Callable, Dict, Optional, Union

from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess
from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess


class TranslationPreprocess(Seq2SeqPreprocess):
Expand Down Expand Up @@ -45,3 +45,4 @@ class TranslationData(Seq2SeqData):
"""Data module for Translation tasks."""

preprocess_cls = TranslationPreprocess
postprocess_cls = Seq2SeqPostprocess
4 changes: 1 addition & 3 deletions flash/text/seq2seq/translation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,7 @@ def compute(self):
return tensor(0.0, device=self.r.device)

if self.smooth:
precision_scores = torch.add(self.numerator, torch.ones(
self.n_gram
)) / torch.add(self.denominator, torch.ones(self.n_gram))
precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0)
else:
precision_scores = self.numerator / self.denominator

Expand Down
10 changes: 5 additions & 5 deletions flash/text/seq2seq/translation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class TranslationTask(Seq2SeqTask):
loss_fn: Loss function for training.
optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
metrics: Metrics to compute for training and evaluation.
learning_rate: Learning rate to use for training, defaults to `3e-4`
learning_rate: Learning rate to use for training, defaults to `1e-5`
val_target_max_length: Maximum length of targets in validation. Defaults to `128`
num_beams: Number of beams to use in validation when generating predictions. Defaults to `4`
n_gram: Maximum n_grams to use in metric calculation. Defaults to `4`
Expand All @@ -41,11 +41,11 @@ def __init__(
loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam,
metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None,
learning_rate: float = 3e-4,
learning_rate: float = 1e-5,
val_target_max_length: Optional[int] = 128,
num_beams: Optional[int] = 4,
n_gram: bool = 4,
smooth: bool = False,
smooth: bool = True,
):
self.save_hyperparameters()
super().__init__(
Expand All @@ -70,11 +70,11 @@ def compute_metrics(self, generated_tokens, batch, prefix):
tgt_lns = self.tokenize_labels(batch["labels"])
# wrap targets in list as score expects a list of potential references
tgt_lns = [[reference] for reference in tgt_lns]
result = self.bleu(generated_tokens, tgt_lns)
result = self.bleu(self._postprocess.uncollate(generated_tokens), tgt_lns)
self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True)

def _ci_benchmark_fn(self, history: List[Dict[str, Any]]):
    """
    This function is used only for debugging usage with CI

    Checks that the last entry of ``history`` reports a ``val_bleu_score``
    greater than ``0.6``.

    Raises:
        AssertionError: If the final ``val_bleu_score`` is not greater than ``0.6``.
    """
    final_bleu = history[-1]["val_bleu_score"]
    # Raised explicitly instead of via ``assert`` so the check still runs
    # under ``python -O`` and the failure reports the offending score.
    if not final_bleu > 0.6:
        raise AssertionError(f"Expected val_bleu_score > 0.6, got {final_bleu}")
2 changes: 1 addition & 1 deletion flash_examples/finetuning/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/")

backbone = "t5-small"
backbone = "Helsinki-NLP/opus-mt-en-ro"

# 2. Load the data
datamodule = TranslationData.from_csv(
Expand Down
Loading