Add benchmark made of multiple text datasets #354

Merged · 38 commits · Aug 3, 2023
Changes from all commits
19ee3d5
WIP implementation of multi-dataset NLP benchmark
610v4nn1 Jul 25, 2023
a00ed47
First implementation of 5 datasets NLP benchmark
610v4nn1 Jul 25, 2023
fe325e8
add test for new data module
610v4nn1 Jul 25, 2023
b21c0b4
Made implementation thread safe. Add dataset selection to speedup loa…
610v4nn1 Jul 27, 2023
0dd1618
fix labels map
610v4nn1 Jul 27, 2023
2198d2a
Add quick test
610v4nn1 Jul 27, 2023
c4404a9
fix docstring and defaults. remove unused code.
610v4nn1 Jul 27, 2023
15c0745
fix typo
610v4nn1 Jul 27, 2023
cb4d723
make flake happy
610v4nn1 Jul 27, 2023
ef956b3
skip long test
610v4nn1 Jul 27, 2023
2d3de75
import pytest
610v4nn1 Jul 27, 2023
5eac7fc
skip test
610v4nn1 Jul 31, 2023
d5c8333
rename domain
610v4nn1 Jul 31, 2023
4048350
import pytest
610v4nn1 Jul 31, 2023
6f81901
remove amazon reviews
610v4nn1 Aug 1, 2023
015ef3f
WIP implementation of multi-dataset NLP benchmark
610v4nn1 Jul 25, 2023
dc5d80e
First implementation of 5 datasets NLP benchmark
610v4nn1 Jul 25, 2023
63ca3a6
add test for new data module
610v4nn1 Jul 25, 2023
0b9a90b
Made implementation thread safe. Add dataset selection to speedup loa…
610v4nn1 Jul 27, 2023
1bfbc4d
fix labels map
610v4nn1 Jul 27, 2023
7bc816b
Add quick test
610v4nn1 Jul 27, 2023
678a4a7
fix docstring and defaults. remove unused code.
610v4nn1 Jul 27, 2023
d31a13a
fix typo
610v4nn1 Jul 27, 2023
0bd72ec
make flake happy
610v4nn1 Jul 27, 2023
bf2f8f2
skip long test
610v4nn1 Jul 27, 2023
cb6fe3b
import pytest
610v4nn1 Jul 27, 2023
df4d172
skip test
610v4nn1 Jul 31, 2023
5060e79
rename domain
610v4nn1 Jul 31, 2023
d001151
import pytest
610v4nn1 Jul 31, 2023
e7985ec
remove amazon reviews
610v4nn1 Aug 1, 2023
ca1c526
fix order agnews labels
610v4nn1 Aug 1, 2023
9f11751
fix test skip reason
610v4nn1 Aug 1, 2023
56bb55a
improve tests, adapt to data incremental module
610v4nn1 Aug 3, 2023
8452b84
fix and merge
610v4nn1 Aug 3, 2023
9c26c1a
add seed randint
610v4nn1 Aug 3, 2023
c82285e
fix exception message
610v4nn1 Aug 3, 2023
3b009a7
avoid copying metadata to change num classes
610v4nn1 Aug 3, 2023
1ad978e
fix generator and features
610v4nn1 Aug 3, 2023
169 changes: 169 additions & 0 deletions src/renate/benchmark/datasets/nlp_datasets.py
@@ -1,13 +1,16 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import functools
import logging
from typing import Any, Dict, Optional

import datasets
import torch
import transformers
from datasets import load_dataset

from renate import defaults
from renate.benchmark.datasets.base import DataIncrementalDataModule
from renate.data.data_module import RenateDataModule


@@ -134,3 +137,169 @@ def tokenize_fn(batch):
self._val_data = _InputTargetWrapper(self._val_data, self._target_column)
else:
self._train_data, self._val_data = self._split_train_val_data(self._train_data)


class MultiTextDataModule(DataIncrementalDataModule):
"""
Inspired by the dataset used in "Episodic Memory in Lifelong Language Learning"
by d’Autume et al., this is a collection of four datasets that we call domains:
AGNews, Yelp, DBPedia and Yahoo Answers.

The output space is the union of the output spaces of all the domains.
The dataset has 33 classes: 4 from AGNews, 5 from Yelp, 14 from DBPedia, and 10 from Yahoo.

The maximum allowed size for the training set is 115000 and for the test set is 7600.
Each domain will have the same fixed size.

Args:
data_path: The path to the folder where the data files will be downloaded to.
tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/
for more information on tokenizers.
tokenizer_kwargs: Keyword arguments passed when calling the tokenizer's ``__call__``
function. Typical options are `max_length`, `padding` and `truncation`.
See https://huggingface.co/docs/tokenizers/
for more information on tokenizers. If `None` is passed, this defaults to
`{"padding": "max_length", "max_length": 128, "truncation": True}`.
data_id: The dataset (domain) to be used.
train_size: The size of the data stored as the training set; must be at most 115000.
test_size: The size of the data stored as the test set; must be at most 7600.
val_size: Fraction of the training data to be used for validation.
seed: Seed used to fix random number generation.
"""

_multi_dataset_info = {
"ag_news": ["text", "label"],
"yelp_review_full": ["text", "label"],
"dbpedia_14": ["content", "label"],
"yahoo_answers_topics": ["question_title", "topic"],
}
_labels_map = {
"ag_news0": 0,
"ag_news1": 1,
"ag_news2": 2,
"ag_news3": 3,
"yelp_review_full0": 4,
"yelp_review_full1": 5,
"yelp_review_full2": 6,
"yelp_review_full3": 7,
"yelp_review_full4": 8,
"dbpedia_140": 9,
"dbpedia_141": 10,
"dbpedia_142": 11,
"dbpedia_143": 12,
"dbpedia_144": 13,
"dbpedia_145": 14,
"dbpedia_146": 15,
"dbpedia_147": 16,
"dbpedia_148": 17,
"dbpedia_149": 18,
"dbpedia_1410": 19,
"dbpedia_1411": 20,
"dbpedia_1412": 21,
"dbpedia_1413": 22,
"yahoo_answers_topics0": 23,
"yahoo_answers_topics1": 24,
"yahoo_answers_topics2": 25,
"yahoo_answers_topics3": 26,
"yahoo_answers_topics4": 27,
"yahoo_answers_topics5": 28,
"yahoo_answers_topics6": 29,
"yahoo_answers_topics7": 30,
"yahoo_answers_topics8": 31,
"yahoo_answers_topics9": 32,
}

domains = _multi_dataset_info.keys()

def __init__(
self,
data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_id: str,
tokenizer_kwargs: Optional[Dict[str, Any]] = None,
train_size: int = defaults.SMALL_TRAIN_SET_SIZE,
test_size: int = defaults.SMALL_TEST_SET_SIZE,
Comment on lines +220 to +221

Contributor:
what is the intuition of selecting a subset for this specific dataset?

Contributor Author (610v4nn1):
I set a relatively small value by default because I expect it to be closer to the actual usage than the max value

val_size: float = defaults.VALIDATION_SIZE,
seed: int = defaults.SEED,
):
super().__init__(data_path=data_path, data_id=data_id, val_size=val_size, seed=seed)

if train_size > 115000:
raise ValueError("The `train_size` must be smaller than or equal to 115000")
self._train_size = train_size

if test_size > 7600:
raise ValueError("The `test_size` must be smaller than 7600")
self._test_size = test_size

self._tokenizer = tokenizer
self._tokenizer_kwargs = tokenizer_kwargs or defaults.TOKENIZER_KWARGS

if data_id not in self.domains:
raise ValueError(
f"The selected domain is not available. Select one among " f"{self.domains}"
)

self.data_id = data_id

def prepare_data(self) -> None:
"""Download dataset."""

for split in ["train", "test"] + (["validation"] if self._val_size > 0 else []):
load_dataset(self.data_id, split=split, cache_dir=self._data_path)

def setup(self) -> None:
"""Set up train, test and val datasets."""

rnd_gen = torch.Generator().manual_seed(self._seed)

def preprocess(example, text_field_name, label_field_name):
return {
**self._tokenizer(example[text_field_name], **self._tokenizer_kwargs),
"label": self._labels_map[f"{self.data_id}{example[label_field_name]}"],
}

def get_split(split_name):
dataset = load_dataset(self.data_id, split=split_name, cache_dir=self._data_path)
# The following is a hack needed because the output space of the new dataset is
# the union of the output spaces of the single datasets.
# HF datasets checks the max label id, so we need to make sure we update it;
# without this change the setup would fail with a ValueError (label id > max labels).
new_features = dataset.features.copy()
new_features[self._multi_dataset_info[self.data_id][1]] = datasets.ClassLabel(
num_classes=33
)

dataset = dataset.cast(new_features)

if "train" == split_name:
set_size = self._train_size
else:
set_size = self._test_size

rnd_idx = torch.randint(
low=0,
high=len(dataset),
size=(set_size,),
generator=rnd_gen,
).tolist()
dataset = dataset.select(indices=rnd_idx)

dataset = dataset.map(
functools.partial(
preprocess,
text_field_name=self._multi_dataset_info[self.data_id][0],
label_field_name=self._multi_dataset_info[self.data_id][1],
),
remove_columns=list(dataset.features),
num_proc=4,
)

dataset.set_format(type="torch")

return _InputTargetWrapper(dataset)

self._train_data = get_split("train")
self._test_data = get_split("test")
if self._val_size > 0:
self._val_data = get_split("validation")
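
To make the label-union logic above concrete, here is a small illustrative sketch (not part of the diff) of how a domain-local label id is remapped into the shared 33-class output space; the entries are copied from the `_labels_map` defined in `MultiTextDataModule`, and the helper name `to_global_label` is hypothetical.

# Illustrative only: mirrors the `_labels_map` lookup done in
# MultiTextDataModule.setup() (key = dataset name concatenated with the raw label id).
labels_map = {
    "ag_news2": 2,                 # AGNews occupies global labels 0-3
    "yelp_review_full0": 4,        # Yelp occupies global labels 4-8
    "dbpedia_143": 12,             # DBPedia occupies global labels 9-22
    "yahoo_answers_topics9": 32,   # Yahoo occupies global labels 23-32
}

def to_global_label(data_id: str, raw_label: int) -> int:
    """Map a domain-local label id to the 33-class union label space."""
    return labels_map[f"{data_id}{raw_label}"]

assert to_global_label("dbpedia_14", 3) == 12
assert to_global_label("yahoo_answers_topics", 9) == 32
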
2 changes: 2 additions & 0 deletions src/renate/defaults.py
@@ -104,6 +104,8 @@

# Benchmark datasets/models
TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True}
SMALL_TRAIN_SET_SIZE = 1000
SMALL_TEST_SET_SIZE = 1000
Comment on lines +107 to +108

Contributor:
names still imply that they are generally used. I was thinking more something along the lines of MULTI_TEXT_TRAIN_SET_SIZE

Contributor Author (610v4nn1):
I would prefer not to have per-dataset default training/test set sizes.



def scheduler(config_space: Dict[str, Any], mode: str, metric: str):
65 changes: 65 additions & 0 deletions test/renate/benchmark/datasets/test_multi_data_nlp.py
@@ -0,0 +1,65 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import pytest
import torch
import transformers

from renate.benchmark.datasets.nlp_datasets import MultiTextDataModule


@pytest.mark.skip(reason="This test create problems with the syne-tune redirect test")
def test_multi_data_nlp_small(tmpdir):
TRAIN_SIZE = 100
TEST_SIZE = 100

data = MultiTextDataModule(
tmpdir,
train_size=TRAIN_SIZE,
test_size=TEST_SIZE,
data_id="ag_news",
tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"),
seed=42,
)

data.prepare_data()
data.setup()

assert len(data.train_data()) == TRAIN_SIZE
assert len(data.test_data()) == TEST_SIZE

first_input_agnews = data.train_data()[0][0]["input_ids"]

data.data_id = "dbpedia_14"
data.setup()

tr_data_dbpedia = data.train_data()
te_data_dbpedia = data.test_data()
assert len(tr_data_dbpedia) == TRAIN_SIZE
assert len(te_data_dbpedia) == TEST_SIZE

first_input_dbpedia = data.train_data()[0][0]["input_ids"]

assert not torch.all(torch.eq(first_input_dbpedia, first_input_agnews))


@pytest.mark.skip(reason="This test requires downloading and processing four datasets.")
def test_multi_data_nlp_full(tmpdir):
TRAIN_SIZE = 115000
TEST_SIZE = 7600

for d in MultiTextDataModule.domains:
data = MultiTextDataModule(
tmpdir,
train_size=TRAIN_SIZE,
test_size=TEST_SIZE,
data_id=d,
tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"),
)

data.prepare_data()
data.setup()

tr_data = data.train_data()
te_data = data.test_data()
assert len(tr_data) == TRAIN_SIZE
assert len(te_data) == TEST_SIZE
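
For completeness, a minimal usage sketch along the lines of the test above. Assumptions: the DistilBERT tokenizer, the cache path and the small split sizes are illustrative, and each item returned by `train_data()` is an `(inputs, label)` pair, as the test's indexing suggests.

import transformers
from renate.benchmark.datasets.nlp_datasets import MultiTextDataModule

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Hypothetical local cache directory; any writable path should do.
dm = MultiTextDataModule(
    "/tmp/multi_text_data",
    tokenizer=tokenizer,
    data_id="ag_news",
    train_size=100,
    test_size=100,
    seed=42,
)
dm.prepare_data()  # downloads the train/test splits for the selected domain
dm.setup()         # tokenizes, remaps labels into the 33-class space, returns torch tensors

inputs, label = dm.train_data()[0]
print(inputs["input_ids"].shape, label)  # e.g. torch.Size([128]) and a global label id in 0-3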