Add NLP Components to Benchmarking #213

Merged: 5 commits, May 4, 2023. Changes shown from 4 commits.
13 changes: 5 additions & 8 deletions examples/nlp_finetuning/renate_config.py
@@ -1,7 +1,6 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 import torch
 import transformers
@@ -13,26 +12,24 @@
 from renate.models.renate_module import RenateWrapper


-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
Contributor: If we're disallowing path objects here, should we also do that in the other examples?

Contributor (Author): Complex types such as Path are no longer supported for our *_fn functions since 0.2.0. I can check for more occurrences, but I believe I removed most of them earlier.

Contributor: I see occurrences of Path in most example config files. If you don't want to remove them now, we should create an issue to do it later.

Contributor (Author): I've updated the remaining config files as well as the documentation.

Contributor: Nice, thanks!

"""Returns a DistilBert classification model."""
transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2, return_dict=False
)
model = RenateWrapper(transformer_model, loss_fn=torch.nn.CrossEntropyLoss())
if model_state_url is not None:
state_dict = torch.load(str(model_state_url))
state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model


def data_module_fn(
data_path: Union[Path, str], chunk_id: int, seed: int = defaults.SEED
) -> RenateDataModule:
def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule:
"""Returns one of two movie review datasets depending on `chunk_id`."""
tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
dataset_name = "imdb" if chunk_id else "rotten_tomatoes"
data_module = HuggingfaceTextDataModule(
str(data_path),
data_path,
dataset_name=dataset_name,
tokenizer=tokenizer,
val_size=0.2,
Expand Down
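For illustration, a minimal sketch of how the simplified entry point is called after this change (the import and state path are hypothetical; since 0.2.0 the *_fn functions take plain strings, so callers convert Path objects themselves):

```python
from pathlib import Path

from renate_config import model_fn  # hypothetical import of the example config above

# Fresh DistilBert classifier, no stored state.
model = model_fn()

# Entry points take plain strings since 0.2.0; convert Path objects up front.
state_path = Path("/tmp/model_state.pt")  # hypothetical location
model = model_fn(model_state_url=str(state_path))
```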
16 changes: 14 additions & 2 deletions src/renate/benchmark/datasets/nlp_datasets.py
@@ -6,6 +6,7 @@
 import datasets
 import torch
 import transformers
+from datasets import get_dataset_infos

 from renate import defaults
 from renate.data.data_module import RenateDataModule
@@ -79,10 +80,21 @@ def __init__(
     def prepare_data(self) -> None:
         """Download data."""
         split_names = datasets.get_dataset_split_names(self._dataset_name)
-        if not "train" in split_names:
+        if "train" not in split_names:
             raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'train' split.")
-        if not "test" in split_names:
+        if "test" not in split_names:
             raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'test' split.")
+        available_columns = list(get_dataset_infos(self._dataset_name)["default"].features)
+        if self._input_column not in available_columns:
+            raise ValueError(
+                f"Input column '{self._input_column}' does not exist in {self._dataset_name}. "
+                f"Available columns: {available_columns}."
+            )
+        if self._target_column not in available_columns:
+            raise ValueError(
+                f"Target column '{self._target_column}' does not exist in {self._dataset_name}. "
+                f"Available columns: {available_columns}."
+            )
         self._train_data = datasets.load_dataset(
             self._dataset_name, split="train", cache_dir=self._data_path
         )
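The new column validation relies on datasets.get_dataset_infos, which fetches per-configuration metadata (typically without downloading the full data). A minimal sketch of what it surfaces; the dataset and column names here are only illustrative:

```python
from datasets import get_dataset_infos

# Features of the "default" configuration, e.g. ['text', 'label'] for
# rotten_tomatoes. A typo such as input_column="txt" now fails fast with a
# ValueError listing these columns instead of a late KeyError.
available_columns = list(get_dataset_infos("rotten_tomatoes")["default"].features)
print(available_columns)
```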
50 changes: 42 additions & 8 deletions src/renate/benchmark/experiment_config.py
@@ -4,7 +4,9 @@

 import torch
 from torchvision.transforms import transforms
+from transformers import AutoTokenizer

+from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule
 from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule
 from renate.benchmark.models import (
     MultiLayerPerceptron,
@@ -21,6 +23,7 @@
     VisionTransformerL16,
     VisionTransformerL32,
 )
+from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer
 from renate.benchmark.scenarios import (
     BenchmarkScenario,
     ClassIncrementalScenario,
@@ -49,6 +52,7 @@
     "VisionTransformerL16": VisionTransformerL16,
     "VisionTransformerL32": VisionTransformerL32,
     "VisionTransformerH14": VisionTransformerH14,
+    "HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer,
 }


@@ -60,6 +64,7 @@ def model_fn(
     num_outputs: Optional[int] = None,
     num_hidden_layers: Optional[int] = None,
     hidden_size: Optional[Tuple[int]] = None,
+    pretrained_model_name: Optional[str] = None,
 ) -> RenateModule:
     """Returns a model instance."""
     if model_name not in models:
@@ -69,11 +74,17 @@
     if updater == "Avalanche-iCaRL":
         model_kwargs["prediction_strategy"] = ICaRLClassificationStrategy()
     if model_name == "MultiLayerPerceptron":
-        model_kwargs = {
-            "num_inputs": num_inputs,
-            "num_hidden_layers": num_hidden_layers,
-            "hidden_size": hidden_size,
-        }
+        model_kwargs.update(
+            {
+                "num_inputs": num_inputs,
+                "num_hidden_layers": num_hidden_layers,
+                "hidden_size": hidden_size,
+            }
+        )
+    elif model_name == "HuggingFaceTransformer":
+        if updater == "Avalanche-iCaRL":
+            raise ValueError("Transformers do not support iCaRL.")
+        model_kwargs["pretrained_model_name"] = pretrained_model_name
     if num_outputs is not None:
         model_kwargs["num_outputs"] = num_outputs
     if model_state_url is None:
@@ -85,14 +96,31 @@


 def get_data_module(
-    data_path: str, dataset_name: str, val_size: float, seed: int
+    data_path: str,
+    dataset_name: str,
+    val_size: float,
+    seed: int,
+    pretrained_model_name: Optional[str],
+    input_column: Optional[str],
+    target_column: Optional[str],
 ) -> RenateDataModule:
     if dataset_name in TorchVisionDataModule.dataset_dict:
         return TorchVisionDataModule(
             data_path, dataset_name=dataset_name, val_size=val_size, seed=seed
         )
     if dataset_name in ["CLEAR10", "CLEAR100"]:
         return CLEARDataModule(data_path, dataset_name=dataset_name, val_size=val_size, seed=seed)
+    if dataset_name.startswith("hfd-"):
+        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+        return HuggingfaceTextDataModule(
+            data_path=data_path,
+            dataset_name=dataset_name[4:],
+            input_column=input_column,
+            target_column=target_column,
+            tokenizer=tokenizer,
+            val_size=val_size,
+            seed=seed,
+        )
     raise ValueError(f"Unknown dataset `{dataset_name}`.")


@@ -191,12 +219,18 @@ def data_module_fn(
     input_dim: Optional[Tuple[int]] = None,
     feature_idx: Optional[int] = None,
     randomness: Optional[float] = None,
+    pretrained_model_name: Optional[str] = None,
+    input_column: Optional[str] = None,
+    target_column: Optional[str] = None,
 ):
     data_module = get_data_module(
         data_path=data_path,
         dataset_name=dataset_name,
         val_size=val_size,
         seed=seed,
+        pretrained_model_name=pretrained_model_name,
+        input_column=input_column,
+        target_column=target_column,
     )
     return get_scenario(
         scenario_name=scenario_name,
@@ -222,7 +256,7 @@ def _get_normalize_transform(dataset_name):

 def train_transform(dataset_name: str) -> Optional[transforms.Compose]:
     """Returns a transform function to be used in the training."""
-    if dataset_name in ["MNIST", "FashionMNIST"]:
+    if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"):
         return None
     elif dataset_name in ["CIFAR10", "CIFAR100"]:
         return transforms.Compose(
@@ -237,7 +271,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]:

 def test_transform(dataset_name: str) -> Optional[transforms.Normalize]:
     """Returns a transform function to be used for validation or testing."""
-    if dataset_name in ["MNIST", "FashionMNIST"]:
+    if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"):
         return None
     elif dataset_name in ["CIFAR10", "CIFAR100"]:
         return _get_normalize_transform(dataset_name)
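Taken together, these changes let a benchmark select the new NLP components through the existing config mechanism. A hedged sketch of the plumbing (keyword arguments follow the signatures in this diff; the dataset, model id, and column names are illustrative):

```python
from renate.benchmark.experiment_config import get_data_module, model_fn

# The "hfd-" prefix routes to HuggingfaceTextDataModule; the rest of the name
# is passed on to Hugging Face datasets. Column names follow the dataset schema.
data_module = get_data_module(
    data_path="/tmp/data",
    dataset_name="hfd-rotten_tomatoes",
    val_size=0.2,
    seed=0,
    pretrained_model_name="distilbert-base-uncased",
    input_column="text",
    target_column="label",
)

# The matching model; note that "HuggingFaceTransformer" rejects Avalanche-iCaRL.
model = model_fn(
    model_name="HuggingFaceTransformer",
    pretrained_model_name="distilbert-base-uncased",
    num_outputs=2,
)
```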
43 changes: 43 additions & 0 deletions src/renate/benchmark/models/transformer.py
@@ -0,0 +1,43 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor
from transformers import AutoModelForSequenceClassification

from renate.models import RenateModule


class HuggingFaceSequenceClassificationTransformer(RenateModule):
Contributor: Is that necessary, or could we just use a RenateWrapper, as I did in the NLP example? I'm fine with having it, as it might be a bit more convenient, but there's always a trade-off with additional code to maintain.

Contributor (Author): RenateWrapper does not allow saving additional constructor arguments, in this case pretrained_model_name and num_outputs.

Contributor: Makes sense.

"""RenateModule which wraps around Hugging Face transformers.

Args:
pretrained_model_name: Hugging Face model id.
num_outputs: Number of outputs.
loss_fn: The loss function to be optimized during the training.
"""

def __init__(
self,
pretrained_model_name: str,
num_outputs: int,
loss_fn: nn.Module = nn.CrossEntropyLoss(),
) -> None:
super().__init__(
constructor_arguments={
"pretrained_model_name": pretrained_model_name,
"num_outputs": num_outputs,
},
loss_fn=loss_fn,
)
self._model = AutoModelForSequenceClassification.from_pretrained(
pretrained_model_name, num_labels=num_outputs, return_dict=False
)

def forward(self, x: Dict[str, Tensor], task_id: Optional[str] = None) -> torch.Tensor:
return self._model(**x)[0]

def _add_task_params(self, task_id: str) -> None:
assert not len(self._tasks_params_ids), "Transformer does not work for multiple tasks."
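As the thread above notes, the advantage over a plain RenateWrapper is that the constructor arguments are registered with RenateModule, so the model can be re-instantiated from a saved state. A minimal usage sketch (model id and inputs are illustrative):

```python
from transformers import AutoTokenizer

from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer

model = HuggingFaceSequenceClassificationTransformer(
    pretrained_model_name="distilbert-base-uncased", num_outputs=2
)

# forward() takes the tokenizer's output dict (input_ids, attention_mask, ...).
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
batch = tokenizer(["a fine movie", "a dull one"], padding=True, return_tensors="pt")
logits = model(dict(batch))  # shape: [2, num_outputs]
```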
2 changes: 1 addition & 1 deletion src/renate/memory/buffer.py
@@ -1,7 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
-from collections import defaultdict
 import copy
+from collections import defaultdict
 from typing import Callable, Dict, Optional

 import torch
2 changes: 1 addition & 1 deletion src/renate/training/training.py
@@ -31,7 +31,7 @@

 import renate
 from renate import defaults
-from renate.cli.parsing_functions import to_dense_str, get_data_module_fn_kwargs
+from renate.cli.parsing_functions import get_data_module_fn_kwargs, to_dense_str
 from renate.utils.file import move_to_uri
 from renate.utils.module import get_and_prepare_data_module, import_module
 from renate.utils.syne_tune import (
@@ -5,5 +5,5 @@
     "dataset": "mnist.json",
     "backend": "local",
     "job_name": "class-incremental-mlp-avalanche-icarl",
-    "expected_accuracy": [0.9981087446212769, 0.5656219124794006]
+    "expected_accuracy": [0.993380606174469, 0.8330068588256836]
 }
2 changes: 1 addition & 1 deletion test/integration_tests/run_experiment.py
@@ -57,7 +57,7 @@ def load_config(scenario_file, model_file, updater_file, dataset_file):
         f"--max-time",
         type=int,
         default=12 * 3600,
-        help="Seed.",
+        help="Maximum execution time.",
     )
     args = parser.parse_args()
     config_space = load_config(
2 changes: 1 addition & 1 deletion test/integration_tests/run_quick_test.py
@@ -67,4 +67,4 @@

 # Noticed different accuracy scores across Mac and GitHub Actions Workflows (which run on Linux)
 # TODO see if we can align the Mac and Linux results
-assert pytest.approx(test_config["expected_accuracy"]) == accuracies
+assert pytest.approx(test_config["expected_accuracy"]) == accuracies, accuracies
Contributor: I don't understand this syntax?

Contributor (Author): If the assertion fails, it prints the accuracies.
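For reference, this is Python's two-argument assert form: the expression after the comma becomes the AssertionError message and is evaluated only when the condition is false. A standalone sketch:

```python
accuracies = [0.91, 0.87]
expected_accuracy = [0.99, 0.83]

# `assert condition, message`: on failure the message expression (here the
# actual accuracies) is attached to the AssertionError.
assert expected_accuracy == accuracies, accuracies
# AssertionError: [0.91, 0.87]
```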
