Add NLP Components to Benchmarking (#213)
wistuba authored and lballes committed May 24, 2023
1 parent aeae35c commit 83d2123
Showing 18 changed files with 267 additions and 68 deletions.
2 changes: 1 addition & 1 deletion doc/benchmarking/custom_benchmarks.rst
@@ -40,7 +40,7 @@ with your data module as follows.
.. code-block:: python
def data_module_fn(
-data_path: Union[Path, str], chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]]
+data_path: str, chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]]
):
data_module = CustomDataModule(data_path=data_path, seed=seed)
return ClassIncrementalScenario(
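The snippet is cut off here by the diff view. For orientation, a hedged sketch of how the resulting function might be called (values are illustrative, not part of this commit):

.. code-block:: python

   # Hypothetical call: split ten classes into five two-class chunks.
   scenario = data_module_fn(
       data_path="./data",
       chunk_id=0,
       seed=0,
       class_groupings=((0, 1), (2, 3), (4, 5), (6, 7), (8, 9)),
   )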
8 changes: 8 additions & 0 deletions doc/benchmarking/renate_benchmarks.rst
@@ -93,6 +93,10 @@ The full list of models and model names including a short description is provide
* - `~renate.benchmark.models.vision_transformer.VisionTransformerH14`
- Huge `Vision Transformer <https://arxiv.org/abs/2010.11929>`_ architecture for images of size 224x224 with patch size 14.
- * ``num_outputs``: Output dimensionality, for classification the number of classes.
+* - `~renate.benchmark.models.transformer.HuggingFaceSequenceClassificationTransformer`
+- Wrapper around Hugging Face transformers.
+- * ``pretrained_model_name``: Hugging Face `transformer ID <https://huggingface.co/models>`__.
+* ``num_outputs``: The number of classes.
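
A hedged sketch of selecting this model through a benchmark ``config_space`` (key names mirror the ``model_fn`` arguments in ``experiment_config.py`` below; exact forwarding by the launcher is an assumption):

.. code-block:: python

   config_space = {
       "model_name": "HuggingFaceTransformer",  # registry name added in this commit
       "pretrained_model_name": "distilbert-base-uncased",  # Hugging Face transformer ID
       "num_outputs": 2,  # number of classes
   }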


.. _benchmarking-renate-benchmarks-datasets:
@@ -133,6 +137,10 @@ The following table contains the list of supported datasets.
- Image Classification
- 60k train, 10k test, 10 classes, image shape 28x28x1
- Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012.
+* - hfd-{dataset_name}
+- multiple
+- Any `Hugging Face dataset <https://huggingface.co/datasets>`__ can be used. Just prepend the prefix ``hfd-``, e.g., ``hfd-rotten_tomatoes``. Select input and target columns via ``config_space``, e.g., add ``"input_column": "text", "target_column": "label"`` for the `rotten_tomatoes <https://huggingface.co/datasets/rotten_tomatoes>`__ example.
+- Please refer to `the official documentation <https://huggingface.co/datasets>`__.
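
A hedged sketch of the matching ``config_space`` entries for such a dataset (column names are specific to ``rotten_tomatoes``; ``pretrained_model_name`` is required so a tokenizer can be built, see ``get_data_module`` below):

.. code-block:: python

   config_space = {
       "dataset_name": "hfd-rotten_tomatoes",  # "hfd-" prefix selects a Hugging Face dataset
       "input_column": "text",  # column holding the raw text
       "target_column": "label",  # column holding the class label
       "pretrained_model_name": "distilbert-base-uncased",  # used to build the tokenizer
   }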

.. _benchmarking-renate-benchmarks-scenarios:

Expand Down
18 changes: 9 additions & 9 deletions doc/getting_started/how_to_renate_config.rst
@@ -16,7 +16,7 @@ Its signature is

.. code-block:: python
-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
A :py:class:`~renate.models.renate_module.RenateModule` is a :code:`torch.nn.Module` with some
additional functionality relevant to continual learning.
@@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
-:lines: 13-39
+:lines: 14-39

If you are using a torch model with **no or fixed hyperparameters**, you can use
:py:class:`~renate.models.renate_module.RenateWrapper`.
@@ -40,7 +40,7 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`.
.. code-block:: python
:caption: Example
-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(28 * 28, 10) # Instantiate your torch model.
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
@@ -58,7 +58,7 @@ Its signature is

.. code-block:: python
-def data_module_fn(data_path: Union[Path, str], seed: int = defaults.SEED) -> RenateDataModule:
+def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule:
:py:class:`~renate.data.data_module.RenateDataModule` provides a structured interface to
download, set up, and access train/val/test datasets.
@@ -67,7 +67,7 @@ such as data subsampling or splitting.

.. literalinclude:: ../../examples/getting_started/renate_config.py
:caption: Example
-:lines: 42-72
+:lines: 43-69

Transforms
==========
@@ -143,11 +143,11 @@ Let us assume we already have a config file in which we implemented a simple lin

.. code-block:: python
-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(28 * 28, 10)
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
-state_dict = torch.load(str(model_state_url))
+state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model
@@ -156,11 +156,11 @@ The natural change would be to change it to something like

.. code-block:: python
-def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule:
my_torch_model = torch.nn.Linear(num_inputs, num_outputs)
model = RenateWrapper(my_torch_model)
if model_state_url is not None:
-state_dict = torch.load(str(model_state_url))
+state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model
10 changes: 5 additions & 5 deletions doc/getting_started/output.rst
@@ -31,13 +31,13 @@ In the following, we refer with :code:`model_fn` to the function defined by the
Output Saved Locally
~~~~~~~~~~~~~~~~~~~~

-If :code:`next_state_url` is a path to a local folder, loading the updated model can be done as follows:
+If :code:`output_state_url` is a path to a local folder, loading the updated model can be done as follows:

.. code-block:: python
-from renate.defaults import current_state_folder, model_file
+from renate.defaults import input_state_folder, model_file
-my_model = model_fn(model_file(current_state_folder(next_state_url)))
+my_model = model_fn(model_file(input_state_folder(output_state_url)))
Output Saved on S3
~~~~~~~~~~~~~~~~~~
@@ -46,9 +46,9 @@ If the Renate output was saved on S3, the model checkpoint :code:`model.ckpt` ca

.. code-block:: python
-from renate.defaults import current_state_folder, model_file
+from renate.defaults import input_state_folder, model_file
-print(model_file(current_state_folder(next_state_url)))
+print(model_file(input_state_folder(output_state_url)))
and then loaded via

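The continuation ("and then loaded via") is cut off by the diff view. A hedged sketch of that loading step (plain ``boto3`` download; bucket name and key are illustrative):

.. code-block:: python

   import boto3

   # Download the checkpoint from S3, then restore the model via model_fn.
   boto3.client("s3").download_file("my-bucket", "renate-output/model.ckpt", "model.ckpt")
   my_model = model_fn("model.ckpt")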
13 changes: 5 additions & 8 deletions examples/nlp_finetuning/renate_config.py
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

import torch
import transformers
@@ -13,26 +12,24 @@
from renate.models.renate_module import RenateWrapper


-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
"""Returns a DistilBert classification model."""
transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2, return_dict=False
)
model = RenateWrapper(transformer_model, loss_fn=torch.nn.CrossEntropyLoss())
if model_state_url is not None:
-state_dict = torch.load(str(model_state_url))
+state_dict = torch.load(model_state_url)
model.load_state_dict(state_dict)
return model


-def data_module_fn(
-data_path: Union[Path, str], chunk_id: int, seed: int = defaults.SEED
-) -> RenateDataModule:
+def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule:
"""Returns one of two movie review datasets depending on `chunk_id`."""
tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
dataset_name = "imdb" if chunk_id else "rotten_tomatoes"
data_module = HuggingfaceTextDataModule(
-str(data_path),
+data_path,
dataset_name=dataset_name,
tokenizer=tokenizer,
val_size=0.2,
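As a hedged smoke test of the config above (not part of the commit; assumes network access for the Hugging Face downloads, and "./data" is an arbitrary local cache path):

model = model_fn()  # fresh DistilBERT classifier with two labels
data_module = data_module_fn(data_path="./data", chunk_id=0)  # chunk 0 -> rotten_tomatoes
data_module.prepare_data()  # downloads the data and checks the splits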
7 changes: 3 additions & 4 deletions examples/simple_classifier_cifar10/renate_config.py
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Optional

import torch
from torchvision import transforms
@@ -13,12 +12,12 @@
from renate.models import RenateModule


-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
"""Returns a model instance."""
if model_state_url is None:
model = ResNet18CIFAR()
else:
-state_dict = torch.load(str(model_state_url))
+state_dict = torch.load(model_state_url)
model = ResNet18CIFAR.from_state_dict(state_dict)
return model

7 changes: 3 additions & 4 deletions examples/train_mlp_locally/renate_config.py
@@ -1,7 +1,6 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
-from pathlib import Path
-from typing import Callable, Optional, Union
+from typing import Callable, Optional

import torch
from torchvision.transforms import transforms
@@ -34,14 +33,14 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) ->
return class_incremental_scenario


-def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule:
+def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
"""Returns a model instance."""
if model_state_url is None:
model = MultiLayerPerceptron(
num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128
)
else:
-state_dict = torch.load(str(model_state_url))
+state_dict = torch.load(model_state_url)
model = MultiLayerPerceptron.from_state_dict(state_dict)
return model

16 changes: 14 additions & 2 deletions src/renate/benchmark/datasets/nlp_datasets.py
@@ -6,6 +6,7 @@
import datasets
import torch
import transformers
+from datasets import get_dataset_infos

from renate import defaults
from renate.data.data_module import RenateDataModule
@@ -79,10 +80,21 @@ def __init__(
def prepare_data(self) -> None:
"""Download data."""
split_names = datasets.get_dataset_split_names(self._dataset_name)
if not "train" in split_names:
if "train" not in split_names:
raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'train' split.")
if not "test" in split_names:
if "test" not in split_names:
raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'test' split.")
+available_columns = list(get_dataset_infos(self._dataset_name)["default"].features)
+if self._input_column not in available_columns:
+raise ValueError(
+f"Input column '{self._input_column}' does not exist in {self._dataset_name}. "
+f"Available columns: {available_columns}."
+)
+if self._target_column not in available_columns:
+raise ValueError(
+f"Target column '{self._target_column}' does not exist in {self._dataset_name}. "
+f"Available columns: {available_columns}."
+)
self._train_data = datasets.load_dataset(
self._dataset_name, split="train", cache_dir=self._data_path
)
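A hedged usage sketch of the validation added above (constructor arguments inferred from this commit; the misspelled input column is meant to trigger the new ValueError):

import transformers
from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
data_module = HuggingfaceTextDataModule(
    "./data",
    dataset_name="rotten_tomatoes",
    input_column="txt",  # misspelled on purpose; the dataset's column is "text"
    target_column="label",
    tokenizer=tokenizer,
)
data_module.prepare_data()  # raises ValueError listing the available columns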
50 changes: 42 additions & 8 deletions src/renate/benchmark/experiment_config.py
@@ -4,7 +4,9 @@

import torch
from torchvision.transforms import transforms
+from transformers import AutoTokenizer

+from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule
from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule
from renate.benchmark.models import (
MultiLayerPerceptron,
@@ -21,6 +23,7 @@
VisionTransformerL16,
VisionTransformerL32,
)
+from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer
from renate.benchmark.scenarios import (
BenchmarkScenario,
ClassIncrementalScenario,
@@ -49,6 +52,7 @@
"VisionTransformerL16": VisionTransformerL16,
"VisionTransformerL32": VisionTransformerL32,
"VisionTransformerH14": VisionTransformerH14,
"HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer,
}


@@ -60,6 +64,7 @@ def model_fn(
num_outputs: Optional[int] = None,
num_hidden_layers: Optional[int] = None,
hidden_size: Optional[Tuple[int]] = None,
+pretrained_model_name: Optional[str] = None,
) -> RenateModule:
"""Returns a model instance."""
if model_name not in models:
@@ -69,11 +74,17 @@
if updater == "Avalanche-iCaRL":
model_kwargs["prediction_strategy"] = ICaRLClassificationStrategy()
if model_name == "MultiLayerPerceptron":
-model_kwargs = {
-"num_inputs": num_inputs,
-"num_hidden_layers": num_hidden_layers,
-"hidden_size": hidden_size,
-}
+model_kwargs.update(
+{
+"num_inputs": num_inputs,
+"num_hidden_layers": num_hidden_layers,
+"hidden_size": hidden_size,
+}
+)
+elif model_name == "HuggingFaceTransformer":
+if updater == "Avalanche-iCaRL":
+raise ValueError("Transformers do not support iCaRL.")
+model_kwargs["pretrained_model_name"] = pretrained_model_name
if num_outputs is not None:
model_kwargs["num_outputs"] = num_outputs
if model_state_url is None:
Expand All @@ -85,14 +96,31 @@ def model_fn(


def get_data_module(
-data_path: str, dataset_name: str, val_size: float, seed: int
+data_path: str,
+dataset_name: str,
+val_size: float,
+seed: int,
+pretrained_model_name: Optional[str],
+input_column: Optional[str],
+target_column: Optional[str],
) -> RenateDataModule:
if dataset_name in TorchVisionDataModule.dataset_dict:
return TorchVisionDataModule(
data_path, dataset_name=dataset_name, val_size=val_size, seed=seed
)
if dataset_name in ["CLEAR10", "CLEAR100"]:
return CLEARDataModule(data_path, dataset_name=dataset_name, val_size=val_size, seed=seed)
+if dataset_name.startswith("hfd-"):
+tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
+return HuggingfaceTextDataModule(
+data_path=data_path,
+dataset_name=dataset_name[4:],
+input_column=input_column,
+target_column=target_column,
+tokenizer=tokenizer,
+val_size=val_size,
+seed=seed,
+)
raise ValueError(f"Unknown dataset `{dataset_name}`.")


@@ -191,12 +219,18 @@ def data_module_fn(
input_dim: Optional[Tuple[int]] = None,
feature_idx: Optional[int] = None,
randomness: Optional[float] = None,
+pretrained_model_name: Optional[str] = None,
+input_column: Optional[str] = None,
+target_column: Optional[str] = None,
):
data_module = get_data_module(
data_path=data_path,
dataset_name=dataset_name,
val_size=val_size,
seed=seed,
+pretrained_model_name=pretrained_model_name,
+input_column=input_column,
+target_column=target_column,
)
return get_scenario(
scenario_name=scenario_name,
Expand All @@ -222,7 +256,7 @@ def _get_normalize_transform(dataset_name):

def train_transform(dataset_name: str) -> Optional[transforms.Compose]:
"""Returns a transform function to be used in the training."""
if dataset_name in ["MNIST", "FashionMNIST"]:
if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"):
return None
elif dataset_name in ["CIFAR10", "CIFAR100"]:
return transforms.Compose(
@@ -237,7 +271,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]:

def test_transform(dataset_name: str) -> Optional[transforms.Normalize]:
"""Returns a transform function to be used for validation or testing."""
if dataset_name in ["MNIST", "FashionMNIST"]:
if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"):
return None
elif dataset_name in ["CIFAR10", "CIFAR100"]:
return _get_normalize_transform(dataset_name)
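Taken together, a hedged end-to-end sketch of the new NLP path through this module (values illustrative; model_fn's leading parameters are not visible in this diff, so keyword-only usage is assumed):

from renate.benchmark.experiment_config import get_data_module, model_fn

# The "hfd-" prefix routes to HuggingfaceTextDataModule with an AutoTokenizer.
data_module = get_data_module(
    data_path="./data",
    dataset_name="hfd-rotten_tomatoes",
    val_size=0.1,
    seed=0,
    pretrained_model_name="distilbert-base-uncased",
    input_column="text",
    target_column="label",
)

# Matching transformer model via the registry entry added above.
model = model_fn(
    model_name="HuggingFaceTransformer",
    pretrained_model_name="distilbert-base-uncased",
    num_outputs=2,
)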