From 9a5205b486ac43df1287ffcfe39a72e6c3f5eab3 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 21 May 2021 22:02:35 +0100 Subject: [PATCH 1/8] Fix translation --- flash/text/seq2seq/core/data.py | 33 ++++++++++++++++++- flash/text/seq2seq/summarization/data.py | 41 ++---------------------- flash/text/seq2seq/translation/data.py | 3 +- flash/text/seq2seq/translation/model.py | 10 +++--- flash_examples/finetuning/translation.py | 4 +-- 5 files changed, 43 insertions(+), 48 deletions(-) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 08a3151a50..a5f105d1d2 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -21,7 +21,7 @@ import flash from flash.core.data.data_module import DataModule from flash.core.data.data_source import DataSource, DefaultDataSources -from flash.core.data.process import Preprocess +from flash.core.data.process import Postprocess, Preprocess from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _TEXT_AVAILABLE @@ -240,7 +240,38 @@ def collate(self, samples: Any) -> Tensor: return default_data_collator(samples) +class Seq2SeqPostprocess(Postprocess): + + def __init__(self): + super().__init__() + + if not _TEXT_AVAILABLE: + raise ModuleNotFoundError("Please, pip install -e '.[text]'") + + self._backbone = None + self._tokenizer = None + + @property + def backbone(self): + backbone_state = self.get_state(Seq2SeqBackboneState) + if backbone_state is not None: + return backbone_state.backbone + + @property + def tokenizer(self): + if self.backbone is not None and self.backbone != self._backbone: + self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + self._backbone = self.backbone + return self._tokenizer + + def uncollate(self, generated_tokens: Any) -> Any: + pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) + pred_str = [str.strip(s) for s in pred_str] + return pred_str + + class Seq2SeqData(DataModule): """Data module for Seq2Seq tasks.""" preprocess_cls = Seq2SeqPreprocess + postprocess_cls = Seq2SeqPostprocess diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index fd6823a0f4..31f81f9ffa 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -11,47 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any - -from flash.core.data.process import Postprocess -from flash.core.utilities.imports import _TEXT_AVAILABLE -from flash.text.seq2seq.core.data import Seq2SeqBackboneState, Seq2SeqData, Seq2SeqPreprocess - -if _TEXT_AVAILABLE: - from transformers import AutoTokenizer - - -class SummarizationPostprocess(Postprocess): - - def __init__(self): - super().__init__() - - if not _TEXT_AVAILABLE: - raise ModuleNotFoundError("Please, pip install -e '.[text]'") - - self._backbone = None - self._tokenizer = None - - @property - def backbone(self): - backbone_state = self.get_state(Seq2SeqBackboneState) - if backbone_state is not None: - return backbone_state.backbone - - @property - def tokenizer(self): - if self.backbone is not None and self.backbone != self._backbone: - self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) - self._backbone = self.backbone - return self._tokenizer - - def uncollate(self, generated_tokens: Any) -> Any: - pred_str = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - pred_str = [str.strip(s) for s in pred_str] - return pred_str +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess class SummarizationData(Seq2SeqData): preprocess_cls = Seq2SeqPreprocess - postprocess_cls = SummarizationPostprocess + postprocess_cls = Seq2SeqPostprocess diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index 057ce41869..0b9e7a3ce7 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Callable, Dict, Optional, Union -from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPreprocess +from flash.text.seq2seq.core.data import Seq2SeqData, Seq2SeqPostprocess, Seq2SeqPreprocess class TranslationPreprocess(Seq2SeqPreprocess): @@ -45,3 +45,4 @@ class TranslationData(Seq2SeqData): """Data module for Translation tasks.""" preprocess_cls = TranslationPreprocess + postprocess_cls = Seq2SeqPostprocess diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index da943ec1f4..e9734cfa8f 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -28,7 +28,7 @@ class TranslationTask(Seq2SeqTask): loss_fn: Loss function for training. optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`. metrics: Metrics to compute for training and evaluation. - learning_rate: Learning rate to use for training, defaults to `3e-4` + learning_rate: Learning rate to use for training, defaults to `1e-5` val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` n_gram: Maximum n_grams to use in metric calculation. Defaults to `4` @@ -41,11 +41,11 @@ def __init__( loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[pl.metrics.Metric, Mapping, Sequence, None] = None, - learning_rate: float = 3e-4, + learning_rate: float = 1e-5, val_target_max_length: Optional[int] = 128, num_beams: Optional[int] = 4, n_gram: bool = 4, - smooth: bool = False, + smooth: bool = True, ): self.save_hyperparameters() super().__init__( @@ -70,11 +70,11 @@ def compute_metrics(self, generated_tokens, batch, prefix): tgt_lns = self.tokenize_labels(batch["labels"]) # wrap targets in list as score expects a list of potential references tgt_lns = [[reference] for reference in tgt_lns] - result = self.bleu(generated_tokens, tgt_lns) + result = self.bleu(self._postprocess.uncollate(generated_tokens), tgt_lns) self.log(f"{prefix}_bleu_score", result, on_step=False, on_epoch=True, prog_bar=True) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """ This function is used only for debugging usage with CI """ - # assert history[-1]["val_bleu_score"] + assert history[-1]["val_bleu_score"] > 0.6 diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index 5965b3d9fc..44a7abba19 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -20,7 +20,7 @@ # 1. Download the data download_data("https://pl-flash-data.s3.amazonaws.com/wmt_en_ro.zip", "data/") -backbone = "t5-small" +backbone = "Helsinki-NLP/opus-mt-en-ro" # 2. Load the data datamodule = TranslationData.from_csv( @@ -47,7 +47,7 @@ trainer.finetune(model, datamodule=datamodule) # 6. Test model -trainer.test(model) +trainer.test(model, datamodule=datamodule) # 7. Save it! trainer.save_checkpoint("translation_model_en_ro.pt") From 5869c4df7981a71bb298755f1517a7d76bc284e4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 21 May 2021 22:07:24 +0100 Subject: [PATCH 2/8] Fix example --- flash_examples/finetuning/translation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash_examples/finetuning/translation.py b/flash_examples/finetuning/translation.py index 44a7abba19..69b789fa7a 100644 --- a/flash_examples/finetuning/translation.py +++ b/flash_examples/finetuning/translation.py @@ -47,7 +47,7 @@ trainer.finetune(model, datamodule=datamodule) # 6. Test model -trainer.test(model, datamodule=datamodule) +trainer.test(model) # 7. Save it! trainer.save_checkpoint("translation_model_en_ro.pt") From e1cfa5dce259ba894ad473174855627362ee175a Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 21 May 2021 22:23:04 +0100 Subject: [PATCH 3/8] Fix --- flash/text/seq2seq/translation/metric.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/flash/text/seq2seq/translation/metric.py b/flash/text/seq2seq/translation/metric.py index 615b34ddf3..bd3e4fe872 100644 --- a/flash/text/seq2seq/translation/metric.py +++ b/flash/text/seq2seq/translation/metric.py @@ -81,9 +81,7 @@ def compute(self): return tensor(0.0, device=self.r.device) if self.smooth: - precision_scores = torch.add(self.numerator, torch.ones( - self.n_gram - )) / torch.add(self.denominator, torch.ones(self.n_gram)) + precision_scores = (self.numerator + 1.0) / (self.denominator + 1.0) else: precision_scores = self.numerator / self.denominator From 24fbdbad9980392fedd9059fbedf43becc2203a4 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 21 May 2021 23:04:33 +0100 Subject: [PATCH 4/8] Add metric test --- tests/text/translation/test_metric.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 tests/text/translation/test_metric.py diff --git a/tests/text/translation/test_metric.py b/tests/text/translation/test_metric.py new file mode 100644 index 0000000000..86b5784745 --- /dev/null +++ b/tests/text/translation/test_metric.py @@ -0,0 +1,25 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +import torch + +from flash.text.seq2seq.translation.metric import BLEUScore + + +@pytest.mark.parametrize("smooth, expected", [(False, 0.7598), (True, 0.8091)]) +def test_bleu_score(smooth, expected): + translate_corpus = ['the cat is on the mat'.split()] + reference_corpus = [['there is a cat on the mat'.split(), 'a cat is on the mat'.split()]] + metric = BLEUScore(smooth=smooth) + assert torch.allclose(metric(translate_corpus, reference_corpus), torch.tensor(expected), 1e-4) From 6f1ade77d3f8596310638e0efcfead6883f510d1 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 28 May 2021 16:36:53 +0100 Subject: [PATCH 5/8] Fixes --- flash/text/seq2seq/core/data.py | 57 ++++++++++++++++++++++++++++++++- 1 file changed, 56 insertions(+), 1 deletion(-) diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index a5f105d1d2..28deb8d987 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -45,7 +45,8 @@ def __init__( if not _TEXT_AVAILABLE: raise ModuleNotFoundError("Please, pip install -e '.[text]'") - self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) + self.backbone = backbone + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) self.max_source_length = max_source_length self.max_target_length = max_target_length self.padding = padding @@ -71,6 +72,15 @@ def _tokenize_fn( padding=self.padding, ) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class Seq2SeqFileDataSource(Seq2SeqDataSource): @@ -112,6 +122,15 @@ def load_data( def predict_load_data(self, data: Any) -> Union['datasets.Dataset', List[Dict[str, torch.Tensor]]]: return self.load_data(data, columns=["input_ids", "attention_mask"]) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class Seq2SeqCSVDataSource(Seq2SeqFileDataSource): @@ -130,6 +149,15 @@ def __init__( padding=padding, ) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class Seq2SeqJSONDataSource(Seq2SeqFileDataSource): @@ -148,6 +176,15 @@ def __init__( padding=padding, ) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class Seq2SeqSentencesDataSource(Seq2SeqDataSource): @@ -161,6 +198,15 @@ def load_data( data = [data] return [self._tokenize_fn(s) for s in data] + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + @dataclass(unsafe_hash=True, frozen=True) class Seq2SeqBackboneState(ProcessState): @@ -269,6 +315,15 @@ def uncollate(self, generated_tokens: Any) -> Any: pred_str = [str.strip(s) for s in pred_str] return pred_str + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("_tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class Seq2SeqData(DataModule): """Data module for Seq2Seq tasks.""" From a6d8db4a28b676189e1eefa4d3aefa2fb3010d5b Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 28 May 2021 16:48:04 +0100 Subject: [PATCH 6/8] Fix pickle bug for classification --- flash/text/classification/data.py | 46 +++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index d0acf1bda7..249af45105 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -38,6 +38,7 @@ def __init__(self, backbone: str, max_length: int = 128): if not _TEXT_AVAILABLE: raise ModuleNotFoundError("Please, pip install -e '.[text]'") + self.backbone = backbone self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True) self.max_length = max_length @@ -55,6 +56,15 @@ def _transform_label(self, label_to_class_mapping: Dict[str, int], target: str, ex[target] = label_to_class_mapping[ex[target]] return ex + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class TextFileDataSource(TextDataSource): @@ -115,18 +125,45 @@ def load_data( def predict_load_data(self, data: Any, dataset: AutoDataset): return self.load_data(data, dataset, columns=["input_ids", "attention_mask"]) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class TextCSVDataSource(TextFileDataSource): def __init__(self, backbone: str, max_length: int = 128): super().__init__("csv", backbone, max_length=max_length) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class TextJSONDataSource(TextFileDataSource): def __init__(self, backbone: str, max_length: int = 128): super().__init__("json", backbone, max_length=max_length) + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class TextSentencesDataSource(TextDataSource): @@ -143,6 +180,15 @@ def load_data( data = [data] return [self._tokenize_fn(s, ) for s in data] + def __getstate__(self): # TODO: Find out why this is being pickled + state = self.__dict__.copy() + state.pop("tokenizer") + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True) + class TextClassificationPreprocess(Preprocess): From 5e6f7be91bc9650093bda1c70c7302609233a59c Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 28 May 2021 16:48:20 +0100 Subject: [PATCH 7/8] Update CHANGELOG.md --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 58dfd5cb70..99c916b7a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `flash.Trainer.add_argparse_args` not adding any arguments ([#343](https://github.com/PyTorchLightning/lightning-flash/pull/343)) +- Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332)) + + +- Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332)) + + ## [0.3.0] - 2021-05-20 ### Added From b4bfa312e57768c0996ad1d0d0cca33bd9f5857d Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Fri, 28 May 2021 17:55:06 +0100 Subject: [PATCH 8/8] Add tests --- tests/text/classification/test_data.py | 39 ++++++++++ tests/text/seq2seq/core/test_data.py | 74 +++++++++++++++++++ .../{ => seq2seq}/summarization/__init__.py | 0 .../{ => seq2seq}/summarization/test_data.py | 0 .../{ => seq2seq}/summarization/test_model.py | 0 .../{ => seq2seq}/translation/__init__.py | 0 .../{ => seq2seq}/translation/test_data.py | 0 .../{ => seq2seq}/translation/test_metric.py | 0 .../{ => seq2seq}/translation/test_model.py | 0 9 files changed, 113 insertions(+) create mode 100644 tests/text/seq2seq/core/test_data.py rename tests/text/{ => seq2seq}/summarization/__init__.py (100%) rename tests/text/{ => seq2seq}/summarization/test_data.py (100%) rename tests/text/{ => seq2seq}/summarization/test_model.py (100%) rename tests/text/{ => seq2seq}/translation/__init__.py (100%) rename tests/text/{ => seq2seq}/translation/test_data.py (100%) rename tests/text/{ => seq2seq}/translation/test_metric.py (100%) rename tests/text/{ => seq2seq}/translation/test_model.py (100%) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 19a4186389..0564355cd6 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -18,6 +18,16 @@ from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.text import TextClassificationData +from flash.text.classification.data import ( + TextCSVDataSource, + TextDataSource, + TextFileDataSource, + TextJSONDataSource, + TextSentencesDataSource, +) + +if _TEXT_AVAILABLE: + from transformers.tokenization_utils_base import PreTrainedTokenizerBase TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing @@ -92,3 +102,32 @@ def test_from_json(tmpdir): def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): TextClassificationData.from_json("sentence", "lab", backbone=TEST_BACKBONE, train_file="", batch_size=1) + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.parametrize( + "cls, kwargs", + [ + (TextDataSource, {}), + (TextFileDataSource, { + "filetype": "csv" + }), + (TextCSVDataSource, {}), + (TextJSONDataSource, {}), + (TextSentencesDataSource, {}), + ], +) +def test_tokenizer_state(cls, kwargs): + """Tests that the tokenizer is not in __getstate__""" + instance = cls(backbone="sshleifer/tiny-mbart", **kwargs) + state = instance.__getstate__() + tokenizers = [] + for name, attribute in instance.__dict__.items(): + if isinstance(attribute, PreTrainedTokenizerBase): + assert name not in state + setattr(instance, name, None) + tokenizers.append(name) + instance.__setstate__(state) + for name in tokenizers: + assert getattr(instance, name, None) is not None diff --git a/tests/text/seq2seq/core/test_data.py b/tests/text/seq2seq/core/test_data.py new file mode 100644 index 0000000000..bf63bb86fb --- /dev/null +++ b/tests/text/seq2seq/core/test_data.py @@ -0,0 +1,74 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import Path + +import pytest + +from flash.core.utilities.imports import _TEXT_AVAILABLE +from flash.text import TextClassificationData +from flash.text.seq2seq.core.data import ( + Seq2SeqBackboneState, + Seq2SeqCSVDataSource, + Seq2SeqDataSource, + Seq2SeqFileDataSource, + Seq2SeqJSONDataSource, + Seq2SeqPostprocess, + Seq2SeqSentencesDataSource, +) + +if _TEXT_AVAILABLE: + from transformers.tokenization_utils_base import PreTrainedTokenizerBase + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_AVAILABLE, reason="text libraries aren't installed.") +@pytest.mark.parametrize( + "cls, kwargs", + [ + (Seq2SeqDataSource, { + "backbone": "sshleifer/tiny-mbart" + }), + (Seq2SeqFileDataSource, { + "backbone": "sshleifer/tiny-mbart", + "filetype": "csv" + }), + (Seq2SeqCSVDataSource, { + "backbone": "sshleifer/tiny-mbart" + }), + (Seq2SeqJSONDataSource, { + "backbone": "sshleifer/tiny-mbart" + }), + (Seq2SeqSentencesDataSource, { + "backbone": "sshleifer/tiny-mbart" + }), + (Seq2SeqPostprocess, {}), + ], +) +def test_tokenizer_state(cls, kwargs): + """Tests that the tokenizer is not in __getstate__""" + process_state = Seq2SeqBackboneState(backbone="sshleifer/tiny-mbart") + instance = cls(**kwargs) + instance.set_state(process_state) + getattr(instance, "tokenizer", None) + state = instance.__getstate__() + tokenizers = [] + for name, attribute in instance.__dict__.items(): + if isinstance(attribute, PreTrainedTokenizerBase): + assert name not in state + setattr(instance, name, None) + tokenizers.append(name) + instance.__setstate__(state) + for name in tokenizers: + assert getattr(instance, name, None) is not None diff --git a/tests/text/summarization/__init__.py b/tests/text/seq2seq/summarization/__init__.py similarity index 100% rename from tests/text/summarization/__init__.py rename to tests/text/seq2seq/summarization/__init__.py diff --git a/tests/text/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py similarity index 100% rename from tests/text/summarization/test_data.py rename to tests/text/seq2seq/summarization/test_data.py diff --git a/tests/text/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py similarity index 100% rename from tests/text/summarization/test_model.py rename to tests/text/seq2seq/summarization/test_model.py diff --git a/tests/text/translation/__init__.py b/tests/text/seq2seq/translation/__init__.py similarity index 100% rename from tests/text/translation/__init__.py rename to tests/text/seq2seq/translation/__init__.py diff --git a/tests/text/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py similarity index 100% rename from tests/text/translation/test_data.py rename to tests/text/seq2seq/translation/test_data.py diff --git a/tests/text/translation/test_metric.py b/tests/text/seq2seq/translation/test_metric.py similarity index 100% rename from tests/text/translation/test_metric.py rename to tests/text/seq2seq/translation/test_metric.py diff --git a/tests/text/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py similarity index 100% rename from tests/text/translation/test_model.py rename to tests/text/seq2seq/translation/test_model.py