diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py
index 906cdaae..5abe1fc8 100644
--- a/src/renate/benchmark/datasets/nlp_datasets.py
+++ b/src/renate/benchmark/datasets/nlp_datasets.py
@@ -1,13 +1,16 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
+import functools
 import logging
 from typing import Any, Dict, Optional
 
 import datasets
 import torch
 import transformers
+from datasets import load_dataset
 
 from renate import defaults
+from renate.benchmark.datasets.base import DataIncrementalDataModule
 from renate.data.data_module import RenateDataModule
 
 
@@ -134,3 +137,169 @@ def tokenize_fn(batch):
             self._val_data = _InputTargetWrapper(self._val_data, self._target_column)
         else:
             self._train_data, self._val_data = self._split_train_val_data(self._train_data)
+
+
+class MultiTextDataModule(DataIncrementalDataModule):
+    """
+    Inspired by the dataset used in "Episodic Memory in Lifelong Language Learning"
+    by d’Autume et al., this is a collection of four different datasets that we call domains:
+    AGNews, Yelp, DBPedia and Yahoo Answers.
+
+    The output space is the union of the output spaces of all the domains.
+    The dataset has 33 classes: 4 from AGNews, 5 from Yelp, 14 from DBPedia, and 10 from Yahoo.
+
+    The maximum allowed size for the training set is 115000 and for the test set is 7600.
+    Each domain will have the same fixed size.
+
+    Args:
+        data_path: The path to the folder where the data files will be downloaded to.
+        tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/
+            for more information on tokenizers.
+        tokenizer_kwargs: Keyword arguments passed when calling the tokenizer's ``__call__``
+            function. Typical options are `max_length`, `padding` and `truncation`.
+            See https://huggingface.co/docs/tokenizers/
+            for more information on tokenizers. If `None` is passed, this defaults to
+            `{"padding": "max_length", "max_length": 128, "truncation": True}`.
+        data_id: The dataset to be used.
+        train_size: The size of the data stored as training set, must be at most 115000.
+        test_size: The size of the data stored as test set, must be at most 7600.
+        val_size: Fraction of the training data to be used for validation.
+        seed: Seed used to fix random number generation.
+ """ + + _multi_dataset_info = { + "ag_news": ["text", "label"], + "yelp_review_full": ["text", "label"], + "dbpedia_14": ["content", "label"], + "yahoo_answers_topics": ["question_title", "topic"], + } + _labels_map = { + "ag_news0": 0, + "ag_news1": 1, + "ag_news2": 2, + "ag_news3": 3, + "yelp_review_full0": 4, + "yelp_review_full1": 5, + "yelp_review_full2": 6, + "yelp_review_full3": 7, + "yelp_review_full4": 8, + "dbpedia_140": 9, + "dbpedia_141": 10, + "dbpedia_142": 11, + "dbpedia_143": 12, + "dbpedia_144": 13, + "dbpedia_145": 14, + "dbpedia_146": 15, + "dbpedia_147": 16, + "dbpedia_148": 17, + "dbpedia_149": 18, + "dbpedia_1410": 19, + "dbpedia_1411": 20, + "dbpedia_1412": 21, + "dbpedia_1413": 22, + "yahoo_answers_topics0": 23, + "yahoo_answers_topics1": 24, + "yahoo_answers_topics2": 25, + "yahoo_answers_topics3": 26, + "yahoo_answers_topics4": 27, + "yahoo_answers_topics5": 28, + "yahoo_answers_topics6": 29, + "yahoo_answers_topics7": 30, + "yahoo_answers_topics8": 31, + "yahoo_answers_topics9": 32, + } + + domains = _multi_dataset_info.keys() + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_id: str, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + train_size: int = defaults.SMALL_TRAIN_SET_SIZE, + test_size: int = defaults.SMALL_TEST_SET_SIZE, + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__(data_path=data_path, data_id=data_id, val_size=val_size, seed=seed) + + if train_size > 115000: + raise ValueError("The `train_size` must be smaller than or equal to 115000") + self._train_size = train_size + + if test_size > 7600: + raise ValueError("The `test_size` must be smaller than 7600") + self._test_size = test_size + + self._tokenizer = tokenizer + self._tokenizer_kwargs = tokenizer_kwargs or defaults.TOKENIZER_KWARGS + + if data_id not in self.domains: + raise ValueError( + f"The selected domain is not available. 
+            )
+
+        self.data_id = data_id
+
+    def prepare_data(self) -> None:
+        """Download dataset."""
+
+        for split in ["train", "test"] + (["validation"] if self._val_size > 0 else []):
+            load_dataset(self.data_id, split=split, cache_dir=self._data_path)
+
+    def setup(self) -> None:
+        """Set up train, test and val datasets."""
+
+        rnd_gen = torch.Generator().manual_seed(self._seed)
+
+        def preprocess(example, text_field_name, label_field_name):
+            return {
+                **self._tokenizer(example[text_field_name], **self._tokenizer_kwargs),
+                "label": self._labels_map[f"{self.data_id}{example[label_field_name]}"],
+            }
+
+        def get_split(split_name):
+            dataset = load_dataset(self.data_id, split=split_name, cache_dir=self._data_path)
+            # The following is a hack needed because the output space of the new dataset is
+            # the union of the output spaces of the single datasets.
+            # HF datasets checks for the max label id, and we need to make sure we update that;
+            # without this change the setup will fail with a ValueError (label id > max labels).
+            new_features = dataset.features.copy()
+            new_features[self._multi_dataset_info[self.data_id][1]] = datasets.ClassLabel(
+                num_classes=33
+            )
+
+            dataset = dataset.cast(new_features)
+
+            if "train" == split_name:
+                set_size = self._train_size
+            else:
+                set_size = self._test_size
+
+            rnd_idx = torch.randint(
+                low=0,
+                high=len(dataset),
+                size=(set_size,),
+                generator=rnd_gen,
+            ).tolist()
+            dataset = dataset.select(indices=rnd_idx)
+
+            dataset = dataset.map(
+                functools.partial(
+                    preprocess,
+                    text_field_name=self._multi_dataset_info[self.data_id][0],
+                    label_field_name=self._multi_dataset_info[self.data_id][1],
+                ),
+                remove_columns=list(dataset.features),
+                num_proc=4,
+            )
+
+            dataset.set_format(type="torch")
+
+            return _InputTargetWrapper(dataset)
+
+        self._train_data = get_split("train")
+        self._test_data = get_split("test")
+        if self._val_size > 0:
+            self._val_data = get_split("validation")
diff --git a/src/renate/defaults.py b/src/renate/defaults.py
index 27eccd04..7b31583d 100644
--- a/src/renate/defaults.py
+++ b/src/renate/defaults.py
@@ -104,6 +104,8 @@
 
 # Benchmark datasets/models
 TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True}
+SMALL_TRAIN_SET_SIZE = 1000
+SMALL_TEST_SET_SIZE = 1000
 
 
 def scheduler(config_space: Dict[str, Any], mode: str, metric: str):
diff --git a/test/renate/benchmark/datasets/test_multi_data_nlp.py b/test/renate/benchmark/datasets/test_multi_data_nlp.py
new file mode 100644
index 00000000..f91765c2
--- /dev/null
+++ b/test/renate/benchmark/datasets/test_multi_data_nlp.py
@@ -0,0 +1,65 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+import transformers
+
+from renate.benchmark.datasets.nlp_datasets import MultiTextDataModule
+
+
+@pytest.mark.skip(reason="This test creates problems with the syne-tune redirect test")
+def test_multi_data_nlp_small(tmpdir):
+    TRAIN_SIZE = 100
+    TEST_SIZE = 100
+
+    data = MultiTextDataModule(
+        tmpdir,
+        train_size=TRAIN_SIZE,
+        test_size=TEST_SIZE,
+        data_id="ag_news",
+        tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"),
+        seed=42,
+    )
+
+    data.prepare_data()
+    data.setup()
+
+    assert len(data.train_data()) == TRAIN_SIZE
+    assert len(data.test_data()) == TEST_SIZE
+
+    first_input_agnews = data.train_data()[0][0]["input_ids"]
+
+    data.data_id = "dbpedia_14"
+    data.setup()
+
+    tr_data_dbpedia = data.train_data()
+    te_data_dbpedia = data.test_data()
+    assert len(tr_data_dbpedia) == TRAIN_SIZE
+    assert len(te_data_dbpedia) == TEST_SIZE
+
+    first_input_dbpedia = data.train_data()[0][0]["input_ids"]
+
+    assert not torch.all(torch.eq(first_input_dbpedia, first_input_agnews))
+
+
+@pytest.mark.skip(reason="This test requires downloading and processing four datasets.")
+def test_multi_data_nlp_full(tmpdir):
+    TRAIN_SIZE = 115000
+    TEST_SIZE = 7600
+
+    for d in MultiTextDataModule.domains:
+        data = MultiTextDataModule(
+            tmpdir,
+            train_size=TRAIN_SIZE,
+            test_size=TEST_SIZE,
+            data_id=d,
+            tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"),
+        )
+
+        data.prepare_data()
+        data.setup()
+
+        tr_data = data.train_data()
+        te_data = data.test_data()
+        assert len(tr_data) == TRAIN_SIZE
+        assert len(te_data) == TEST_SIZE
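
For reviewers who want to try the new module locally, the sketch below mirrors the skipped
test_multi_data_nlp_small test above. The cache directory "./data", the DistilBERT tokenizer,
and the small subset sizes are illustrative assumptions, not values required by
MultiTextDataModule.

# Minimal usage sketch (assumptions: local "./data" cache directory, DistilBERT tokenizer,
# train/test subsets of 100 examples each).
import transformers

from renate.benchmark.datasets.nlp_datasets import MultiTextDataModule

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Start with the AGNews domain; its labels occupy ids 0-3 of the shared 33-class output space.
data_module = MultiTextDataModule(
    "./data",  # assumed download/cache directory
    tokenizer=tokenizer,
    data_id="ag_news",
    train_size=100,
    test_size=100,
    seed=42,
)
data_module.prepare_data()  # downloads the Hugging Face dataset into data_path
data_module.setup()  # subsamples, tokenizes and wraps the train/test splits

inputs, label = data_module.train_data()[0]  # inputs is a dict containing "input_ids", etc.

# Moving to the next domain reuses the same module: change data_id and call setup() again.
data_module.data_id = "dbpedia_14"
data_module.setup()  # DBPedia labels now map to ids 9-22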