From 5c0e5a33879336deb633cb1e8d47a7225fdb9d73 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Apr 2023 14:43:05 +0200 Subject: [PATCH 01/89] Bump Torch (#200) --- .github/workflows/run_renate.yml | 7 +++---- pyproject.toml | 2 ++ requirements.txt | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_renate.yml b/.github/workflows/run_renate.yml index 88013038..d8fcfab0 100644 --- a/.github/workflows/run_renate.yml +++ b/.github/workflows/run_renate.yml @@ -41,16 +41,15 @@ jobs: role-session-name: integtestsession aws-region: ${{ env.AWS_DEFAULT_REGION }} - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.9 cache: 'pip' - name: Install Renate run: | python -m pip install --upgrade pip - python -m pip install -e '.' - python -m pip install pytest avalanche-lib + python -m pip install -e '.[dev]' - name: Run optional custom command if: ${{ inputs.additional-command != '' }} run: ${{ inputs.additional-command }} diff --git a/pyproject.toml b/pyproject.toml index ca91ecb1..d8a71946 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,12 @@ dynamic = ["version", "readme", "dependencies"] [project.optional-dependencies] avalanche = [ "avalanche_lib==0.3.1", + "torch>=1.10.0, <1.12.2", ] dev = [ "black==23.1.0", "avalanche_lib==0.3.1", + "torch>=1.10.0, <1.12.2", # PyTest Dependencies "pytest==7.2.2", "pytest-cov==4.0.0", diff --git a/requirements.txt b/requirements.txt index d8b5f150..ff626d22 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ numpy>=1.17.2, <1.24.2 -torch>=1.10.0, <1.12.2 +torch>=1.10.0, <1.13.2 pandas>=1.4.0, <1.5.3 boto3>=1.26.0, <1.26.116 requests>=2.28.0, <2.28.2 From 99719591dce55333bf03062cd87775df977bf5e6 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Apr 2023 18:14:35 +0200 Subject: [PATCH 02/89] Describe Usage of Avalanche updaters and their Limitations (#202) --- doc/getting_started/avalanche.rst | 53 ++++++++++++++++++++ doc/getting_started/index.rst | 3 +- doc/getting_started/supported_algorithms.rst | 14 +++++- 3 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 doc/getting_started/avalanche.rst diff --git a/doc/getting_started/avalanche.rst b/doc/getting_started/avalanche.rst new file mode 100644 index 00000000..bcf61b02 --- /dev/null +++ b/doc/getting_started/avalanche.rst @@ -0,0 +1,53 @@ +Avalanche - Usage and Limitations +********************************* + +`Avalanche `__ is a popular continual learning library for rapid prototyping +and benchmarking algorithms. For that reason, we make some Avalanche updaters available in Renate. Avalanche updaters +can be used in the same way you can use other Renate updaters but they have some limitations you should be aware of. + +Usage +===== +Using Avalanche updaters for training works the same as with Renate updaters which we explained earlier in +:doc:`how_to_run_training`. You can select an Avalanche updater by passing the respective string to the ``updater`` +argument of :py:func:`~renate.training.training.run_training_job`. The available Avalanche options are +``"Avalanche-ER"``, ``"Avalanche-EWC"``, ``"Avalanche-LwF"``, and ``"Avalanche-iCaRL"``. +More details about these algorithms are given in :doc:`supported_algorithms`. +Further Avalanche updaters can be created if needed. + +Limitations +=========== +Not all Renate features work with every Avalanche updater. 
In the following, we list the limitations you face when +using an Avalanche updater. These limitations are subject to change. + +PyTorch 1.13.1 +-------------- +The checkpointing functionality in Avalanche does not work with the latest version of PyTorch 1. Therefore, you will +be required to use PyTorch 1.12.1 instead. + +.. warning:: + There is a `known vulnerability `__ with PyTorch <= 1.13.0. + +No Scalable Buffers +------------------- +Renate stores the memory buffer on disk. In contrast, Avalanche requires it to be in memory. Therefore, Avalanche +updaters may not work if you intend to use very large buffer sizes. + +No Multi-Fidelity HPO +--------------------- +Currently, we do not support multi-fidelity hyperparameter optimization with Avalanche updaters. For that reason, +please do not use ``asha`` as a scheduler but use ``random`` or ``bo`` instead. +For more details about HPO in Renate, please refer to the +:ref:`Renate HPO section `. + +No Early-Stopping +----------------- +Currently, Avalanche updaters will not work with early stopping. Please keep ``early_stopping=True`` (default setting). + +iCaRL Limitations +----------------- +The implementation of iCaRL makes strong assumptions about the continual learning scenario. Only classification is +supported and it assumes that data points of a particular class only occur in one update. If this is not the case, +iCaRL will crash as soon as you attempt an update with a class seen before. Furthermore, it requires a specific +model interface to account for its strategy. For that purpose, please create a model class which extends +:py:class:`RenateBenchmarkingModule ` or copy the relevant +parts over to your ``RenateModule``. diff --git a/doc/getting_started/index.rst b/doc/getting_started/index.rst index 0ca3f868..ac344f0a 100644 --- a/doc/getting_started/index.rst +++ b/doc/getting_started/index.rst @@ -4,7 +4,7 @@ Getting Started This section covers the usage of Renate from the installation to the training of a model. The content is intended to explain the basic steps to be taken when creating a new -training pipeline based on Rente. +training pipeline based on Renate. If your goal is to test multiple Renate algorithms on your dataset or to test a specific algorithm on multiple datasets and scenarios, you will be better served by the benchmarking functionalities @@ -18,3 +18,4 @@ provided in :doc:`../benchmarking/renate_benchmarks`. how_to_run_training output supported_algorithms + avalanche diff --git a/doc/getting_started/supported_algorithms.rst b/doc/getting_started/supported_algorithms.rst index 3ec323b4..8fd3e197 100644 --- a/doc/getting_started/supported_algorithms.rst +++ b/doc/getting_started/supported_algorithms.rst @@ -3,7 +3,7 @@ Supported Algorithms Renate provides implementations of various continual learning methods. The following table provides an overview with links to the documentation, and a short description. When initiating model updates -using Renate (e.g., using :py:func:`renate.training.training.run_training_job`; see +using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`; see :doc:`how_to_run_training`), a method may be selected using the shorthand provided below. .. list-table:: Supported Algorithms @@ -36,3 +36,15 @@ using Renate (e.g., using :py:func:`renate.training.training.run_training_job`; * - ``"FineTuning"`` - :py:class:`Learner ` - A simple method which trains the current model on only the new data without any sort of mitigation for forgetting. 
Used as "lower bound" baseline in experiments. + * - ``"Avalanche-ER"`` + - :py:class:`AvalancheReplayLearner ` + - A wrapper which gives access to Experience Replay as implemented in the Avalanche library. This method is the equivalent to our Offline-ER. + * - ``"Avalanche-EWC"`` + - :py:class:`AvalancheEWCLearner ` + - A wrapper which gives access to Elastic Weight Consolidation as implemented in the Avalanche library. EWC updates the model in such a way that the parameters after the update remain close to the parameters before the update to avoid catastrophic forgetting. [`Paper `__] + * - ``"Avalanche-LwF"`` + - :py:class:`AvalancheLwFLearner ` + - A wrapper which gives access to Learning without Forgetting as implemented in the Avalanche library. LwF does not require to retain old data. It assumes that each new data chunk is its own task. A common backbone is shared across all task and each task has its own prediction head. [`Paper `__] + * - ``"Avalanche-iCaRL"`` + - :py:class:`AvalancheICaRLLearner ` + - A wrapper which gives access to iCaRL as implemented in the Avalanche library. This method is limited to class-incremental learning and combines knowledge distillation with nearest neighbors classification. [`Paper `__] From 349c0ad0f613a2bf22e39ab074dc063e3c6071f2 Mon Sep 17 00:00:00 2001 From: Giovanni <52964960+610v4nn1@users.noreply.github.com> Date: Tue, 25 Apr 2023 15:04:52 +0200 Subject: [PATCH 03/89] Update README.rst with paper reference (#205) --- README.rst | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/README.rst b/README.rst index 9ddbf51d..a707773c 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ :target: # :alt: PyPI - Status .. image:: https://img.shields.io/github/v/release/awslabs/Renate - :target: https://github.com/awslabs/Renate/releases/tag/v0.1.0 + :target: https://github.com/awslabs/Renate/releases/tag/v0.2.0 :alt: Latest Release .. image:: https://img.shields.io/pypi/dm/Renate :target: https://pypistats.org/packages/renate @@ -72,11 +72,29 @@ Key features * Advanced HPO functionalities available out-of-the-box * Open for experimentation -Blog posts -========== +Resources +========= -* `Automatically retrain neural networks with Renate `_ +* (blog) `Automatically retrain neural networks with Renate `_ +* (paper) `Renate: A Library for Real-World Continual Learning `_ +Cite Renate +=========== + + .. code-block:: bibtex + + @misc{renate2023, + title = {Renate: A Library for Real-World Continual Learning}, + author = {Martin Wistuba and + Martin Ferianc and + Lukas Balles and + Cedric Archambeau and + Giovanni Zappella}, + year = {2023}, + eprint = {2304.12067}, + archivePrefix = {arXiv}, + primaryClass = {cs.LG} + } What are you looking for? 
========================= From 1f2c26b7e97782403030296826ddf696912a2784 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Wed, 3 May 2023 13:12:27 +0200 Subject: [PATCH 04/89] Add doc page explaining NLP example (#207) --- doc/examples/index.rst | 3 ++- doc/examples/nlp_finetuning.rst | 38 +++++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 doc/examples/nlp_finetuning.rst diff --git a/doc/examples/index.rst b/doc/examples/index.rst index 4c197b62..ec04ca4d 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -6,4 +6,5 @@ Examples :maxdepth: 2 train_mlp_locally - train_classifier_sagemaker \ No newline at end of file + train_classifier_sagemaker + nlp_finetuning \ No newline at end of file diff --git a/doc/examples/nlp_finetuning.rst b/doc/examples/nlp_finetuning.rst new file mode 100644 index 00000000..de86b324 --- /dev/null +++ b/doc/examples/nlp_finetuning.rst @@ -0,0 +1,38 @@ +Working with NLP +**************** + +This example demonstrates how to use Renate to train NLP models. We will train a sequence classifier +to distinguish between positive and negative movie reviews. Using Renate, we will sequentially +train this model on two movie review datasets, called :code:`"imdb"` and :code:`"rotten_tomatoes"`. + +Configuration +============= + +Let us take a look at the :code:`renate_config.py` for this example. In the :code:`model_fn` +function, we use the Hugging Face :code:`transformers` library to instantiate a sequence +classification model. Since this model is static, we can easily turn it into a :code:`RenateModule` +by wrapping it in :py:class:`~renate.models.renate_module.RenateWrapper`. + +In the :code:`data_module_fn`, we load the matching tokenizer from the :code:`transformers` library. +We then use Renate's :py:class:`~renate.benchmark.datasets.nlp_datasets.HuggingfaceTextDataModule` +to access datasets from the `Hugging Face datasets hub `_. This +data module expects the name of a dataset as well as a tokenizer. Here, we load the :code:`"imdb"` +dataset in the first training stage (:code:`chunk_id = 0`) and the :code:`"rotten_tomatoes"` dataset +for the subsequent model update (:code:`chunk_id = 1`). + +The data module will return pre-tokenized data and no further transforms are needed in this case. + +.. literalinclude:: ../../examples/nlp_finetuning/renate_config.py + :lines: 3- + +Training +======== + +As in previous examples, we also include a launch script called :code:`start.py`. For more details +on this see previous examples or :doc:`../getting_started/how_to_run_training`. + +.. 
literalinclude:: ../../examples/nlp_finetuning/start.py + :lines: 3- + + + From 523099f687853a70a055784eb447aac2b3d72f6a Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 4 May 2023 17:39:30 +0200 Subject: [PATCH 05/89] Add NLP Components to Benchmarking (#213) --- doc/benchmarking/custom_benchmarks.rst | 2 +- doc/benchmarking/renate_benchmarks.rst | 8 ++ doc/getting_started/how_to_renate_config.rst | 18 ++-- doc/getting_started/output.rst | 10 +- examples/nlp_finetuning/renate_config.py | 13 +-- .../renate_config.py | 7 +- examples/train_mlp_locally/renate_config.py | 7 +- src/renate/benchmark/datasets/nlp_datasets.py | 16 ++- src/renate/benchmark/experiment_config.py | 50 +++++++-- src/renate/benchmark/models/transformer.py | 43 ++++++++ src/renate/defaults.py | 18 ++-- src/renate/memory/buffer.py | 2 +- src/renate/training/training.py | 2 +- .../configs/suites/quick/avalanche-icarl.json | 2 +- test/integration_tests/run_experiment.py | 2 +- test/integration_tests/run_quick_test.py | 2 +- .../benchmark/test_experimentation_config.py | 100 +++++++++++++++--- test/renate/data/test_data_module.py | 33 +++++- 18 files changed, 267 insertions(+), 68 deletions(-) create mode 100644 src/renate/benchmark/models/transformer.py diff --git a/doc/benchmarking/custom_benchmarks.rst b/doc/benchmarking/custom_benchmarks.rst index 6076333c..65aa2dbc 100644 --- a/doc/benchmarking/custom_benchmarks.rst +++ b/doc/benchmarking/custom_benchmarks.rst @@ -40,7 +40,7 @@ with your data module as follows. .. code-block:: python def data_module_fn( - data_path: Union[Path, str], chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]] + data_path: str, chunk_id: int, seed: int, class_groupings: Tuple[Tuple[int]] ): data_module = CustomDataModule(data_path=data_path, seed=seed) return ClassIncrementalScenario( diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index 036e9cdf..3f128736 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -93,6 +93,10 @@ The full list of models and model names including a short description is provide * - `~renate.benchmark.models.vision_transformer.VisionTransformerH14` - Huge `Vision Transformer `_ architecture for images of size 224x224 with patch size 14. - * ``num_outputs``: Output dimensionality, for classification the number of classes. + * - `~renate.benchmark.models.transformer.HuggingFaceSequenceClassificationTransformer` + - Wrapper around Hugging Face transformers. + - * ``pretrained_model_name``: Hugging Face `transformer ID `__. + * ``num_outputs``: The number of classes. .. _benchmarking-renate-benchmarks-datasets: @@ -133,6 +137,10 @@ The following table contains the list of supported datasets. - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 - Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012. + * - hfd-{dataset_name} + - multiple + - Any `Hugging Face dataset `__ can be used. Just prepend the prefix ``hfd-``, e.g., ``hfd-rotten_tomatoes``. Select input and target columns via ``config_space``, e.g., add ``"input_column": "text", "target_column": "label"`` for the `rotten_tomatoes `__ example. + - Please refer to `the official documentation `__. .. 
_benchmarking-renate-benchmarks-scenarios: diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index e246d5fa..3320cb12 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -16,7 +16,7 @@ Its signature is .. code-block:: python - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: A :py:class:`~renate.models.renate_module.RenateModule` is a :code:`torch.nn.Module` with some additional functionality relevant to continual learning. @@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 13-39 + :lines: 14-39 If you are using a torch model with **no or fixed hyperparameters**, you can use :py:class:`~renate.models.renate_module.RenateWrapper`. @@ -40,7 +40,7 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`. .. code-block:: python :caption: Example - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(28 * 28, 10) # Instantiate your torch model. model = RenateWrapper(my_torch_model) if model_state_url is not None: @@ -58,7 +58,7 @@ Its signature is .. code-block:: python - def data_module_fn(data_path: Union[Path, str], seed: int = defaults.SEED) -> RenateDataModule: + def data_module_fn(data_path: str, seed: int = defaults.SEED) -> RenateDataModule: :py:class:`~renate.data.data_module.RenateDataModule` provides a structured interface to download, set up, and access train/val/test datasets. @@ -67,7 +67,7 @@ such as data subsampling or splitting. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 42-72 + :lines: 43-69 Transforms ========== @@ -143,11 +143,11 @@ Let us assume we already have a config file in which we implemented a simple lin .. code-block:: python - def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(28 * 28, 10) model = RenateWrapper(my_torch_model) if model_state_url is not None: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model @@ -156,11 +156,11 @@ The natural change would be to change it to something like .. 
code-block:: python - def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: + def model_fn(num_inputs: int, num_outputs: int, model_state_url: Optional[str] = None) -> RenateModule: my_torch_model = torch.nn.Linear(num_inputs, num_outputs) model = RenateWrapper(my_torch_model) if model_state_url is not None: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model diff --git a/doc/getting_started/output.rst b/doc/getting_started/output.rst index 1c722bd3..08d36061 100644 --- a/doc/getting_started/output.rst +++ b/doc/getting_started/output.rst @@ -31,13 +31,13 @@ In the following, we refer with :code:`model_fn` to the function defined by the Output Saved Locally ~~~~~~~~~~~~~~~~~~~~ -If :code:`next_state_url` is a path to a local folder, loading the updated model can be done as follows: +If :code:`output_state_url` is a path to a local folder, loading the updated model can be done as follows: .. code-block:: python - from renate.defaults import current_state_folder, model_file + from renate.defaults import input_state_folder, model_file - my_model = model_fn(model_file(current_state_folder(next_state_url))) + my_model = model_fn(model_file(input_state_folder(output_state_url))) Output Saved on S3 ~~~~~~~~~~~~~~~~~~ @@ -46,9 +46,9 @@ If the Renate output was saved on S3, the model checkpoint :code:`model.ckpt` ca .. code-block:: python - from renate.defaults import current_state_folder, model_file + from renate.defaults import input_state_folder, model_file - print(model_file(current_state_folder(next_state_url))) + print(model_file(input_state_folder(output_state_url))) and then loaded via diff --git a/examples/nlp_finetuning/renate_config.py b/examples/nlp_finetuning/renate_config.py index d70ab679..9cdcd7ca 100644 --- a/examples/nlp_finetuning/renate_config.py +++ b/examples/nlp_finetuning/renate_config.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Optional, Union +from typing import Optional import torch import transformers @@ -13,26 +12,24 @@ from renate.models.renate_module import RenateWrapper -def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: """Returns a DistilBert classification model.""" transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained( "distilbert-base-uncased", num_labels=2, return_dict=False ) model = RenateWrapper(transformer_model, loss_fn=torch.nn.CrossEntropyLoss()) if model_state_url is not None: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model -def data_module_fn( - data_path: Union[Path, str], chunk_id: int, seed: int = defaults.SEED -) -> RenateDataModule: +def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule: """Returns one of two movie review datasets depending on `chunk_id`.""" tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased") dataset_name = "imdb" if chunk_id else "rotten_tomatoes" data_module = HuggingfaceTextDataModule( - str(data_path), + data_path, dataset_name=dataset_name, tokenizer=tokenizer, val_size=0.2, diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 6f70102b..563d46e7 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch from torchvision import transforms @@ -13,12 +12,12 @@ from renate.models import RenateModule -def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: """Returns a model instance.""" if model_state_url is None: model = ResNet18CIFAR() else: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model = ResNet18CIFAR.from_state_dict(state_dict) return model diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index 86cfdc94..779f4ef3 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from pathlib import Path -from typing import Callable, Optional, Union +from typing import Callable, Optional import torch from torchvision.transforms import transforms @@ -34,14 +33,14 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> return class_incremental_scenario -def model_fn(model_state_url: Optional[Union[Path, str]] = None) -> RenateModule: +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: """Returns a model instance.""" if model_state_url is None: model = MultiLayerPerceptron( num_inputs=784, num_outputs=10, num_hidden_layers=2, hidden_size=128 ) else: - state_dict = torch.load(str(model_state_url)) + state_dict = torch.load(model_state_url) model = MultiLayerPerceptron.from_state_dict(state_dict) return model diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py index 0cd5eb96..70761bd7 100644 --- a/src/renate/benchmark/datasets/nlp_datasets.py +++ b/src/renate/benchmark/datasets/nlp_datasets.py @@ -6,6 +6,7 @@ import datasets import torch import transformers +from datasets import get_dataset_infos from renate import defaults from renate.data.data_module import RenateDataModule @@ -79,10 +80,21 @@ def __init__( def prepare_data(self) -> None: """Download data.""" split_names = datasets.get_dataset_split_names(self._dataset_name) - if not "train" in split_names: + if "train" not in split_names: raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'train' split.") - if not "test" in split_names: + if "test" not in split_names: raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'test' split.") + available_columns = list(get_dataset_infos(self._dataset_name)["default"].features) + if self._input_column not in available_columns: + raise ValueError( + f"Input column '{self._input_column}' does not exist in {self._dataset_name}. " + f"Available columns: {available_columns}." + ) + if self._target_column not in available_columns: + raise ValueError( + f"Target column '{self._target_column}' does not exist in {self._dataset_name}. " + f"Available columns: {available_columns}." 
+ ) self._train_data = datasets.load_dataset( self._dataset_name, split="train", cache_dir=self._data_path ) diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 68023407..24a12b6b 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -4,7 +4,9 @@ import torch from torchvision.transforms import transforms +from transformers import AutoTokenizer +from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule from renate.benchmark.models import ( MultiLayerPerceptron, @@ -21,6 +23,7 @@ VisionTransformerL16, VisionTransformerL32, ) +from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer from renate.benchmark.scenarios import ( BenchmarkScenario, ClassIncrementalScenario, @@ -49,6 +52,7 @@ "VisionTransformerL16": VisionTransformerL16, "VisionTransformerL32": VisionTransformerL32, "VisionTransformerH14": VisionTransformerH14, + "HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer, } @@ -60,6 +64,7 @@ def model_fn( num_outputs: Optional[int] = None, num_hidden_layers: Optional[int] = None, hidden_size: Optional[Tuple[int]] = None, + pretrained_model_name: Optional[str] = None, ) -> RenateModule: """Returns a model instance.""" if model_name not in models: @@ -69,11 +74,17 @@ def model_fn( if updater == "Avalanche-iCaRL": model_kwargs["prediction_strategy"] = ICaRLClassificationStrategy() if model_name == "MultiLayerPerceptron": - model_kwargs = { - "num_inputs": num_inputs, - "num_hidden_layers": num_hidden_layers, - "hidden_size": hidden_size, - } + model_kwargs.update( + { + "num_inputs": num_inputs, + "num_hidden_layers": num_hidden_layers, + "hidden_size": hidden_size, + } + ) + elif model_name == "HuggingFaceTransformer": + if updater == "Avalanche-iCaRL": + raise ValueError("Transformers do not support iCaRL.") + model_kwargs["pretrained_model_name"] = pretrained_model_name if num_outputs is not None: model_kwargs["num_outputs"] = num_outputs if model_state_url is None: @@ -85,7 +96,13 @@ def model_fn( def get_data_module( - data_path: str, dataset_name: str, val_size: float, seed: int + data_path: str, + dataset_name: str, + val_size: float, + seed: int, + pretrained_model_name: Optional[str], + input_column: Optional[str], + target_column: Optional[str], ) -> RenateDataModule: if dataset_name in TorchVisionDataModule.dataset_dict: return TorchVisionDataModule( @@ -93,6 +110,17 @@ def get_data_module( ) if dataset_name in ["CLEAR10", "CLEAR100"]: return CLEARDataModule(data_path, dataset_name=dataset_name, val_size=val_size, seed=seed) + if dataset_name.startswith("hfd-"): + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) + return HuggingfaceTextDataModule( + data_path=data_path, + dataset_name=dataset_name[4:], + input_column=input_column, + target_column=target_column, + tokenizer=tokenizer, + val_size=val_size, + seed=seed, + ) raise ValueError(f"Unknown dataset `{dataset_name}`.") @@ -191,12 +219,18 @@ def data_module_fn( input_dim: Optional[Tuple[int]] = None, feature_idx: Optional[int] = None, randomness: Optional[float] = None, + pretrained_model_name: Optional[str] = None, + input_column: Optional[str] = None, + target_column: Optional[str] = None, ): data_module = get_data_module( data_path=data_path, dataset_name=dataset_name, val_size=val_size, seed=seed, + 
pretrained_model_name=pretrained_model_name, + input_column=input_column, + target_column=target_column, ) return get_scenario( scenario_name=scenario_name, @@ -222,7 +256,7 @@ def _get_normalize_transform(dataset_name): def train_transform(dataset_name: str) -> Optional[transforms.Compose]: """Returns a transform function to be used in the training.""" - if dataset_name in ["MNIST", "FashionMNIST"]: + if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"): return None elif dataset_name in ["CIFAR10", "CIFAR100"]: return transforms.Compose( @@ -237,7 +271,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]: def test_transform(dataset_name: str) -> Optional[transforms.Normalize]: """Returns a transform function to be used for validation or testing.""" - if dataset_name in ["MNIST", "FashionMNIST"]: + if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"): return None elif dataset_name in ["CIFAR10", "CIFAR100"]: return _get_normalize_transform(dataset_name) diff --git a/src/renate/benchmark/models/transformer.py b/src/renate/benchmark/models/transformer.py new file mode 100644 index 00000000..fe8909d4 --- /dev/null +++ b/src/renate/benchmark/models/transformer.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, Optional + +import torch +import torch.nn as nn +from torch import Tensor +from transformers import AutoModelForSequenceClassification + +from renate.models import RenateModule + + +class HuggingFaceSequenceClassificationTransformer(RenateModule): + """RenateModule which wraps around Hugging Face transformers. + + Args: + pretrained_model_name: Hugging Face model id. + num_outputs: Number of outputs. + loss_fn: The loss function to be optimized during the training. + """ + + def __init__( + self, + pretrained_model_name: str, + num_outputs: int, + loss_fn: nn.Module = nn.CrossEntropyLoss(), + ) -> None: + super().__init__( + constructor_arguments={ + "pretrained_model_name": pretrained_model_name, + "num_outputs": num_outputs, + }, + loss_fn=loss_fn, + ) + self._model = AutoModelForSequenceClassification.from_pretrained( + pretrained_model_name, num_labels=num_outputs, return_dict=False + ) + + def forward(self, x: Dict[str, Tensor], task_id: Optional[str] = None) -> torch.Tensor: + return self._model(**x)[0] + + def _add_task_params(self, task_id: str) -> None: + assert not len(self._tasks_params_ids), "Transformer does not work for multiple tasks." 
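A minimal sketch of how the new wrapper added above can back the ``model_fn`` of a ``renate_config.py``, mirroring the pattern used in the repository's other examples; the model id ``distilbert-base-uncased`` and ``num_outputs=2`` are assumed example values rather than defaults:

.. code-block:: python

    from typing import Optional

    import torch

    from renate.benchmark.models.transformer import (
        HuggingFaceSequenceClassificationTransformer,
    )
    from renate.models import RenateModule


    def model_fn(model_state_url: Optional[str] = None) -> RenateModule:
        """Returns a Hugging Face sequence classifier wrapped as a RenateModule."""
        if model_state_url is None:
            # Fresh model: example values, chosen here only for illustration.
            return HuggingFaceSequenceClassificationTransformer(
                pretrained_model_name="distilbert-base-uncased", num_outputs=2
            )
        # Restore the model from a previously saved Renate state, following the
        # same pattern as the existing ResNet/MLP examples.
        state_dict = torch.load(model_state_url)
        return HuggingFaceSequenceClassificationTransformer.from_state_dict(state_dict)

When benchmarking through ``experiment_config.py`` instead, the same model is selected by passing ``model_name="HuggingFaceTransformer"`` together with ``pretrained_model_name`` in the config space.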
diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 20fd8a84..3131142d 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -120,23 +120,23 @@ def current_timestamp() -> str: return str(datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S-%f")) -def data_folder(working_directory: str): +def data_folder(working_directory: str) -> str: return os.path.join(working_directory, "data") -def input_state_folder(working_directory: str): +def input_state_folder(working_directory: str) -> str: return os.path.join(working_directory, "input_state") -def output_state_folder(working_directory: str): +def output_state_folder(working_directory: str) -> str: return os.path.join(working_directory, "output_state") -def logs_folder(working_directory: str): +def logs_folder(working_directory: str) -> str: return os.path.join(working_directory, "logs") -def model_file(state_folder: str): +def model_file(state_folder: str) -> str: return os.path.join(state_folder, "model.ckpt") @@ -144,17 +144,17 @@ def model_file(state_folder: str): AVALANCHE_CHECKPOINT_NAME = "avalanche.ckpt" -def learner_state_file(state_folder: str): +def learner_state_file(state_folder: str) -> str: return os.path.join(state_folder, LEARNER_CHECKPOINT_NAME) -def avalanche_state_file(state_folder: str): +def avalanche_state_file(state_folder: str) -> str: return os.path.join(state_folder, AVALANCHE_CHECKPOINT_NAME) -def metric_summary_file(logs_folder: str, special_str: str = ""): +def metric_summary_file(logs_folder: str, special_str: str = "") -> str: return os.path.join(logs_folder, f"metrics_summary{special_str}.csv") -def hpo_file(state_folder: str): +def hpo_file(state_folder: str) -> str: return os.path.join(state_folder, "hpo.csv") diff --git a/src/renate/memory/buffer.py b/src/renate/memory/buffer.py index 8ee63130..3849c9ab 100644 --- a/src/renate/memory/buffer.py +++ b/src/renate/memory/buffer.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict import copy +from collections import defaultdict from typing import Callable, Dict, Optional import torch diff --git a/src/renate/training/training.py b/src/renate/training/training.py index c1338471..0b33edc3 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -31,7 +31,7 @@ import renate from renate import defaults -from renate.cli.parsing_functions import to_dense_str, get_data_module_fn_kwargs +from renate.cli.parsing_functions import get_data_module_fn_kwargs, to_dense_str from renate.utils.file import move_to_uri from renate.utils.module import get_and_prepare_data_module, import_module from renate.utils.syne_tune import ( diff --git a/test/integration_tests/configs/suites/quick/avalanche-icarl.json b/test/integration_tests/configs/suites/quick/avalanche-icarl.json index 456194da..3e5ba545 100644 --- a/test/integration_tests/configs/suites/quick/avalanche-icarl.json +++ b/test/integration_tests/configs/suites/quick/avalanche-icarl.json @@ -5,5 +5,5 @@ "dataset": "mnist.json", "backend": "local", "job_name": "class-incremental-mlp-avalanche-icarl", - "expected_accuracy": [0.9981087446212769, 0.5656219124794006] + "expected_accuracy": [0.993380606174469, 0.8330068588256836] } diff --git a/test/integration_tests/run_experiment.py b/test/integration_tests/run_experiment.py index 50dc22d3..ed0c76be 100644 --- a/test/integration_tests/run_experiment.py +++ b/test/integration_tests/run_experiment.py @@ -57,7 +57,7 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): f"--max-time", type=int, default=12 * 3600, - help="Seed.", + help="Maximum execution time.", ) args = parser.parse_args() config_space = load_config( diff --git a/test/integration_tests/run_quick_test.py b/test/integration_tests/run_quick_test.py index f2208e41..f3b5ca79 100644 --- a/test/integration_tests/run_quick_test.py +++ b/test/integration_tests/run_quick_test.py @@ -67,4 +67,4 @@ # Noticed different accuracy scores across Mac and GitHub Actions Workflows (which run on Linux) # TODO see if we can align the Mac and Linux results - assert pytest.approx(test_config["expected_accuracy"]) == accuracies + assert pytest.approx(test_config["expected_accuracy"]) == accuracies, accuracies diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 3682d35d..9b71258e 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -4,6 +4,7 @@ from torchvision.transforms import Compose, Normalize from renate.benchmark import experiment_config +from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule from renate.benchmark.experiment_config import ( data_module_fn, @@ -22,20 +23,24 @@ ImageRotationScenario, PermutationScenario, ) +from renate.models.prediction_strategies import ICaRLClassificationStrategy @pytest.mark.parametrize( "model_name,expected_model_class", - [(model_name, model_class) for model_name, model_class in zip(models.keys(), models.values())], + [(model_name, model_class) for model_name, model_class in models.items()], ) def test_model_fn(model_name, expected_model_class): model = model_fn( model_state_url=None, model_name=model_name, num_inputs=1 if model_name == "MultiLayerPerceptron" else None, - num_outputs=1 if model_name == 
"MultiLayerPerceptron" else None, + num_outputs=1 if model_name in ["MultiLayerPerceptron", "HuggingFaceTransformer"] else None, num_hidden_layers=1 if model_name == "MultiLayerPerceptron" else None, hidden_size=1 if model_name == "MultiLayerPerceptron" else None, + pretrained_model_name="distilbert-base-uncased" + if model_name == "HuggingFaceTransformer" + else None, ) assert isinstance(model, expected_model_class) @@ -47,22 +52,58 @@ def test_model_fn_fails_for_unknown_model(): @pytest.mark.parametrize( - "dataset_name,data_module_class", - (("CIFAR10", TorchVisionDataModule), ("CLEAR10", CLEARDataModule)), + "dataset_name,data_module_class,pretrained_model_name,input_column,target_column", + ( + ("CIFAR10", TorchVisionDataModule, None, None, None), + ("CLEAR10", CLEARDataModule, None, None, None), + ( + "hfd-rotten-tomatoes", + HuggingfaceTextDataModule, + "distilbert-base-uncased", + "text", + "label", + ), + ), ) -def test_get_data_module(tmpdir, dataset_name, data_module_class): - data_module = get_data_module(data_path=tmpdir, dataset_name=dataset_name, val_size=0.5, seed=0) +def test_get_data_module( + tmpdir, dataset_name, data_module_class, pretrained_model_name, input_column, target_column +): + data_module = get_data_module( + data_path=tmpdir, + dataset_name=dataset_name, + val_size=0.5, + seed=0, + pretrained_model_name=pretrained_model_name, + input_column=input_column, + target_column=target_column, + ) assert isinstance(data_module, data_module_class) def test_get_data_module_fails_for_unknown_dataset(tmpdir): unknown_dataset_name = "UNKNOWN_DATASET_NAME" with pytest.raises(ValueError, match=f"Unknown dataset `{unknown_dataset_name}`"): - get_data_module(data_path=tmpdir, dataset_name=unknown_dataset_name, val_size=0.5, seed=0) + get_data_module( + data_path=tmpdir, + dataset_name=unknown_dataset_name, + val_size=0.5, + seed=0, + pretrained_model_name=None, + input_column=None, + target_column=None, + ) def test_get_scenario_fails_for_unknown_scenario(tmpdir): - data_module = get_data_module(data_path=tmpdir, dataset_name="MNIST", val_size=0.5, seed=0) + data_module = get_data_module( + data_path=tmpdir, + dataset_name="MNIST", + val_size=0.5, + seed=0, + pretrained_model_name=None, + input_column=None, + target_column=None, + ) unknown_scenario_name = "UNKNOWN_SCENARIO_NAME" with pytest.raises(ValueError, match=f"Unknown scenario `{unknown_scenario_name}`"): get_scenario( @@ -75,8 +116,13 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): ( ( "ClassIncrementalScenario", - "CIFAR10", - {"class_groupings": ((0, 1), (2, 3, 4), (5, 6))}, + "hfd-trec", + { + "pretrained_model_name": "distilbert-base-uncased", + "input_column": "text", + "target_column": "coarse_label", + "class_groupings": ((0, 1), (2, 3), (4, 5)), + }, ClassIncrementalScenario, 3, ), @@ -122,7 +168,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): ), ), ids=[ - "class_incremental_image", + "class_incremental", "iid", "rotation", "benchmark", @@ -163,7 +209,13 @@ def test_data_module_fn( @pytest.mark.parametrize( "dataset_name,use_transforms", - (("MNIST", False), ("FashionMNIST", False), ("CIFAR10", True), ("CIFAR100", True)), + ( + ("MNIST", False), + ("FashionMNIST", False), + ("CIFAR10", True), + ("CIFAR100", True), + ("hfd-rotten_tomatoes", False), + ), ) def test_transforms(dataset_name, use_transforms): train_preprocessing = train_transform(dataset_name) @@ -181,3 +233,27 @@ def test_transforms_fails_for_unknown_dataset(): for transform_function in [train_transform, 
experiment_config.test_transform]: with pytest.raises(ValueError, match=f"Unknown dataset `{unknown_dataset_set}`"): transform_function(unknown_dataset_set) + + +@pytest.mark.parametrize("model_name", [model_name for model_name in models]) +@pytest.mark.parametrize("updater", ("ER", "Avalanche-iCaRL")) +def test_prediction_strategy_is_correctly_set(model_name, updater): + """If iCaRL is used, the model prediction strategy must be changed.""" + model_kwargs = { + "model_name": model_name, + "updater": updater, + "num_outputs": 2, + } + if model_name == "MultiLayerPerceptron": + model_kwargs.update({"num_inputs": 10, "hidden_size": 10, "num_hidden_layers": 2}) + elif model_name == "HuggingFaceTransformer": + model_kwargs["pretrained_model_name"] = "distilbert-base-uncased" + if model_name == "HuggingFaceTransformer" and updater == "Avalanche-iCaRL": + with pytest.raises(ValueError, match="Transformers do not support iCaRL."): + model_fn(**model_kwargs) + else: + model = model_fn(**model_kwargs) + if updater == "ER": + assert not hasattr(model, "_prediction_strategy") or model._prediction_strategy is None + else: + assert isinstance(model._prediction_strategy, ICaRLClassificationStrategy) diff --git a/test/renate/data/test_data_module.py b/test/renate/data/test_data_module.py index 303a33bc..0f83097d 100644 --- a/test/renate/data/test_data_module.py +++ b/test/renate/data/test_data_module.py @@ -150,7 +150,7 @@ def test_tiny_imagenet_data_module(tmpdir): dict(padding="longest", max_length=512, truncation=True), ], ) -def test_huggingface_data_module( +def test_hugging_face_data_module( tmpdir, dataset_name, input_column, target_column, tokenizer, tokenizer_kwargs ): data_module = HuggingfaceTextDataModule( @@ -177,3 +177,34 @@ def test_huggingface_data_module( assert set(inputs.keys()) == set(["input_ids", "attention_mask"]) assert isinstance(inputs["input_ids"], torch.Tensor) assert isinstance(inputs["attention_mask"], torch.Tensor) + + +@pytest.mark.parametrize("column", ("input", "target"), ids=("input", "target")) +def test_hugging_face_exception_raised_with_wrong_column(tmpdir, column): + input_column = "text" + target_column = "label" + if column == "input": + input_column = "WRONG_COLUMN" + elif column == "target": + target_column = "WRONG_COLUMN" + data_module = HuggingfaceTextDataModule( + data_path=tmpdir, + dataset_name="rotten_tomatoes", + input_column=input_column, + target_column=target_column, + val_size=0.2, + ) + if column == "input": + with pytest.raises( + ValueError, + match="Input column 'WRONG_COLUMN' does not exist in rotten_tomatoes. " + "Available columns: \\['text', 'label'\\].", + ): + data_module.prepare_data() + elif column == "target": + with pytest.raises( + ValueError, + match="Target column 'WRONG_COLUMN' does not exist in rotten_tomatoes. 
" + "Available columns: \\['text', 'label'\\].", + ): + data_module.prepare_data() From d835baa355d16a46c471e8a30d137b5118212f85 Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 5 May 2023 14:43:02 +0200 Subject: [PATCH 06/89] Robust Integration Tests (#214) --- .github/workflows/run_renate.yml | 2 +- doc/getting_started/install.rst | 2 +- src/renate/benchmark/experimentation.py | 2 +- src/renate/cli/run_training.py | 2 +- .../configs/suites/quick/avalanche-er.json | 3 ++- .../configs/suites/quick/avalanche-ewc.json | 3 ++- .../configs/suites/quick/avalanche-icarl.json | 3 ++- .../configs/suites/quick/avalanche-lwf.json | 3 ++- test/integration_tests/configs/suites/quick/cls-er.json | 3 ++- test/integration_tests/configs/suites/quick/der.json | 3 ++- test/integration_tests/configs/suites/quick/er.json | 3 ++- .../configs/suites/quick/fine-tuning.json | 4 +++- test/integration_tests/configs/suites/quick/gdumb.json | 3 ++- test/integration_tests/configs/suites/quick/joint.json | 3 ++- .../configs/suites/quick/offline-er.json | 4 ++-- test/integration_tests/configs/suites/quick/pod-er.json | 3 ++- .../integration_tests/configs/suites/quick/super-er.json | 3 ++- .../configs/updaters/offline-er-buffer500.json | 3 ++- test/integration_tests/run_quick_test.py | 9 ++++----- 19 files changed, 37 insertions(+), 24 deletions(-) diff --git a/.github/workflows/run_renate.yml b/.github/workflows/run_renate.yml index d8fcfab0..59f1dd5d 100644 --- a/.github/workflows/run_renate.yml +++ b/.github/workflows/run_renate.yml @@ -30,7 +30,7 @@ env: jobs: run: - runs-on: Renate_ubuntu-latest_16-core + runs-on: ubuntu-latest timeout-minutes: ${{ inputs.timeout-minutes }} steps: - name: Configure AWS Credentials (if required) diff --git a/doc/getting_started/install.rst b/doc/getting_started/install.rst index 4c4a3485..32e264d7 100644 --- a/doc/getting_started/install.rst +++ b/doc/getting_started/install.rst @@ -7,7 +7,7 @@ Renate is available via PyPI and can be installed using :code:`pip`: pip install Renate -If you want to use additional methods that require the Avalanche library, please use use +If you want to use additional methods that require the Avalanche library, please use .. code-block:: bash diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index d30eb366..b1075d79 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -252,7 +252,7 @@ def _execute_experiment_job_locally( See renate.benchmark.experimentation.execute_experiment_job for more details. 
""" logger.info("Start experiment.") - seed_everything(seed) + seed_everything(seed, True) input_state_url = defaults.input_state_folder(working_directory) output_state_url = defaults.output_state_folder(working_directory) diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index 46a5deea..9ba9484f 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -99,7 +99,7 @@ def run(self): ignore_args=["data_path", "model_state_url"], ) - seed_everything(args.seed) + seed_everything(args.seed, True) self._prepare_data_state_model(args) data_module = get_and_setup_data_module( diff --git a/test/integration_tests/configs/suites/quick/avalanche-er.json b/test/integration_tests/configs/suites/quick/avalanche-er.json index 769ce21a..5badc4df 100644 --- a/test/integration_tests/configs/suites/quick/avalanche-er.json +++ b/test/integration_tests/configs/suites/quick/avalanche-er.json @@ -5,5 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-avalanche-er", - "expected_accuracy": [0.14100000262260437, 0.6794999837875366] + "expected_accuracy_linux": [[0.14100000262260437, 0.6794999837875366], [0.41449999809265137, 0.5740000009536743]], + "expected_accuracy_darwin": [[0.07649999856948853, 0.7114999890327454]] } diff --git a/test/integration_tests/configs/suites/quick/avalanche-ewc.json b/test/integration_tests/configs/suites/quick/avalanche-ewc.json index ee02aeb2..63f762d8 100644 --- a/test/integration_tests/configs/suites/quick/avalanche-ewc.json +++ b/test/integration_tests/configs/suites/quick/avalanche-ewc.json @@ -5,5 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "rotation-mlp-avalanche-ewc", - "expected_accuracy": [0.7580999732017517, 0.9627000093460083] + "expected_accuracy_linux": [[0.7580999732017517, 0.9627000093460083], [0.7551000118255615, 0.9664999842643738]], + "expected_accuracy_darwin": [[0.7497000098228455, 0.9664999842643738]] } diff --git a/test/integration_tests/configs/suites/quick/avalanche-icarl.json b/test/integration_tests/configs/suites/quick/avalanche-icarl.json index 3e5ba545..a0991a93 100644 --- a/test/integration_tests/configs/suites/quick/avalanche-icarl.json +++ b/test/integration_tests/configs/suites/quick/avalanche-icarl.json @@ -5,5 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "class-incremental-mlp-avalanche-icarl", - "expected_accuracy": [0.993380606174469, 0.8330068588256836] + "expected_accuracy_linux": [[0.993380606174469, 0.8330068588256836], [0.9947990775108337, 0.8222330808639526]], + "expected_accuracy_darwin": [[0.9947990775108337, 0.7845249772071838]] } diff --git a/test/integration_tests/configs/suites/quick/avalanche-lwf.json b/test/integration_tests/configs/suites/quick/avalanche-lwf.json index ab54a256..cab2ab47 100644 --- a/test/integration_tests/configs/suites/quick/avalanche-lwf.json +++ b/test/integration_tests/configs/suites/quick/avalanche-lwf.json @@ -5,5 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "permutation-mlp-avalanche-lwf", - "expected_accuracy": [0.7526999711990356, 0.9607999920845032] + "expected_accuracy_linux": [[0.7526999711990356, 0.9607999920845032], [0.6541000008583069, 0.9617999792098999]], + "expected_accuracy_darwin": [[0.7202000021934509, 0.9646999835968018]] } diff --git a/test/integration_tests/configs/suites/quick/cls-er.json b/test/integration_tests/configs/suites/quick/cls-er.json index f32e67a2..f7882609 100644 --- a/test/integration_tests/configs/suites/quick/cls-er.json 
+++ b/test/integration_tests/configs/suites/quick/cls-er.json @@ -5,5 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "class-incremental-mlp-cls-er", - "expected_accuracy": [0.9858155846595764, 0.9740450382232666] + "expected_accuracy_linux": [[0.9858155846595764, 0.9740450382232666]], + "expected_accuracy_darwin": [[0.9858155846595764, 0.9755141735076904]] } diff --git a/test/integration_tests/configs/suites/quick/der.json b/test/integration_tests/configs/suites/quick/der.json index 1c68061d..942c1743 100644 --- a/test/integration_tests/configs/suites/quick/der.json +++ b/test/integration_tests/configs/suites/quick/der.json @@ -5,5 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "feature-sorting-mlp-der", - "expected_accuracy": [0.35339999198913574, 0.3086000084877014] + "expected_accuracy_linux": [[0.35339999198913574, 0.3086000084877014], [0.35179999470710754, 0.3208000063896179]], + "expected_accuracy_darwin": [[0.3400000035762787, 0.3253999948501587]] } diff --git a/test/integration_tests/configs/suites/quick/er.json b/test/integration_tests/configs/suites/quick/er.json index fa14706c..4abe300f 100644 --- a/test/integration_tests/configs/suites/quick/er.json +++ b/test/integration_tests/configs/suites/quick/er.json @@ -5,5 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-er", - "expected_accuracy": [0.5799999833106995, 0.367000013589859] + "expected_accuracy_linux": [[0.5799999833106995, 0.367000013589859], [0.6100000143051147, 0.5975000262260437]], + "expected_accuracy_darwin": [[0.49050000309944153, 0.671500027179718]] } diff --git a/test/integration_tests/configs/suites/quick/fine-tuning.json b/test/integration_tests/configs/suites/quick/fine-tuning.json index 9912fc9b..003d81a7 100644 --- a/test/integration_tests/configs/suites/quick/fine-tuning.json +++ b/test/integration_tests/configs/suites/quick/fine-tuning.json @@ -5,5 +5,7 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "iid-mlp-fine-tuning", - "expected_accuracy": [0.8458999991416931, 0.8458999991416931] + "expected_accuracy_linux": [[0.8458999991416931, 0.8458999991416931], [0.8574000000953674, 0.8574000000953674]], + "expected_accuracy_darwin": [[0.8521000146865845, 0.8521000146865845]] } + diff --git a/test/integration_tests/configs/suites/quick/gdumb.json b/test/integration_tests/configs/suites/quick/gdumb.json index a8919d1d..04631cd6 100644 --- a/test/integration_tests/configs/suites/quick/gdumb.json +++ b/test/integration_tests/configs/suites/quick/gdumb.json @@ -5,5 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-gdumb", - "expected_accuracy": [0.8364999890327454, 0.8634999990463257] + "expected_accuracy_linux": [[0.8364999890327454, 0.8634999990463257]], + "expected_accuracy_darwin": [[0.8364999890327454, 0.8634999990463257]] } diff --git a/test/integration_tests/configs/suites/quick/joint.json b/test/integration_tests/configs/suites/quick/joint.json index 533dc31b..8cb887cd 100644 --- a/test/integration_tests/configs/suites/quick/joint.json +++ b/test/integration_tests/configs/suites/quick/joint.json @@ -5,5 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "iid-mlp-joint", - "expected_accuracy": [0.8639000058174133, 0.8639000058174133] + "expected_accuracy_linux": [[0.8639000058174133, 0.8639000058174133], [0.8618000149726868, 0.8618000149726868]], + "expected_accuracy_darwin": [[0.859499990940094, 0.859499990940094]] } diff --git 
a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 8d25f6f7..6673d49d 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy": [0.7089999914169312, 0.5120000243186951], - "loss_weight_new_data": 0.5 + "expected_accuracy_linux": [[0.7039999961853027, 0.4569999873638153], [0.6664999723434448, 0.5544999837875366]], + "expected_accuracy_darwin": [[0.6965000033378601, 0.4284999966621399]] } diff --git a/test/integration_tests/configs/suites/quick/pod-er.json b/test/integration_tests/configs/suites/quick/pod-er.json index 4e452463..915acb6c 100644 --- a/test/integration_tests/configs/suites/quick/pod-er.json +++ b/test/integration_tests/configs/suites/quick/pod-er.json @@ -5,5 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "hue-shift-mlp-pod-er", - "expected_accuracy": [0.17059999704360962, 0.31299999356269836] + "expected_accuracy_linux": [[0.17059999704360962, 0.31299999356269836], [0.19140000641345978, 0.3246000111103058]], + "expected_accuracy_darwin": [[0.20000000298023224, 0.2637999951839447]] } diff --git a/test/integration_tests/configs/suites/quick/super-er.json b/test/integration_tests/configs/suites/quick/super-er.json index 13d36c4c..4fe2a4e0 100644 --- a/test/integration_tests/configs/suites/quick/super-er.json +++ b/test/integration_tests/configs/suites/quick/super-er.json @@ -5,5 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-super-er", - "expected_accuracy": [0.9330000281333923, 0.9144999980926514] + "expected_accuracy_linux": [[0.9330000281333923, 0.9144999980926514], [0.9434999823570251, 0.8974999785423279]], + "expected_accuracy_darwin": [[0.9390000104904175, 0.9089999794960022]] } diff --git a/test/integration_tests/configs/updaters/offline-er-buffer500.json b/test/integration_tests/configs/updaters/offline-er-buffer500.json index 189d3e6e..eaa22099 100644 --- a/test/integration_tests/configs/updaters/offline-er-buffer500.json +++ b/test/integration_tests/configs/updaters/offline-er-buffer500.json @@ -6,5 +6,6 @@ "weight_decay": 0.0, "batch_size": 256, "memory_batch_size": 256, - "memory_size": 500 + "memory_size": 500, + "loss_weight_new_data": 0.5 } diff --git a/test/integration_tests/run_quick_test.py b/test/integration_tests/run_quick_test.py index f3b5ca79..44029384 100644 --- a/test/integration_tests/run_quick_test.py +++ b/test/integration_tests/run_quick_test.py @@ -5,6 +5,7 @@ import os import subprocess from pathlib import Path +from sys import platform import pandas as pd import pytest @@ -50,7 +51,8 @@ ] ) process.wait() - num_updates = len(test_config["expected_accuracy"]) + expected_accuracy = test_config[f"expected_accuracy_{platform}"] + num_updates = len(test_config["expected_accuracy_darwin"][0]) result_file = ( Path("tmp") / "renate-integration-tests" @@ -64,7 +66,4 @@ accuracies = [float(acc) for acc in list(df.iloc[-1])[1:]] else: accuracies = [] - - # Noticed different accuracy scores across Mac and GitHub Actions Workflows (which run on Linux) - # TODO see if we can align the Mac and Linux results - assert pytest.approx(test_config["expected_accuracy"]) == accuracies, accuracies + assert any([pytest.approx(acc) == accuracies for acc in expected_accuracy]), accuracies From 
d366fe9627b0a04c5076fd4f2d392b06f583c6a4 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 10 May 2023 17:16:39 +0200 Subject: [PATCH 07/89] Update Renate Config Example (#226) --- doc/getting_started/how_to_renate_config.rst | 8 ++--- examples/getting_started/renate_config.py | 33 +++++++++----------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index 3320cb12..a7a08671 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 14-39 + :lines: 14-40 If you are using a torch model with **no or fixed hyperparameters**, you can use :py:class:`~renate.models.renate_module.RenateWrapper`. @@ -67,7 +67,7 @@ such as data subsampling or splitting. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 43-69 + :lines: 43-66 Transforms ========== @@ -112,7 +112,7 @@ These are optional as well but, if omitted, Renate will use :code:`train_transfo .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 76-93 + :lines: 73-90 Custom Metrics ============== @@ -124,7 +124,7 @@ or created ad-hoc by implementing the same interface .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 96- + :lines: 93- To enable the usage of additional metrics in Renate it is sufficient to implement the :code:`metrics_fn` function, returning a dictionary where the key is a string containing the diff --git a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index f99f1670..ddc2c8e8 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -1,6 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Literal, Optional +from typing import Callable, Dict, Optional import torch import torchvision @@ -49,24 +49,21 @@ def prepare_data(self) -> None: # streamline data loading when using multiple training jobs during HPO. torchvision.datasets.MNIST(self._data_path, download=True) - def setup(self, stage: Optional[Literal["train", "val", "test"]] = None) -> None: + def setup(self) -> None: # This sets up train/val/test datasets, assuming data has already been downloaded. 
- if stage in ["train", "val"] or stage is None: - train_data = torchvision.datasets.MNIST( - self._data_path, - train=True, - transform=transforms.ToTensor(), - target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), - ) - self._train_data, self._val_data = self._split_train_val_data(train_data) - - if stage == "test" or stage is None: - self._test_data = torchvision.datasets.MNIST( - self._data_path, - train=False, - transform=transforms.ToTensor(), - target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), - ) + train_data = torchvision.datasets.MNIST( + self._data_path, + train=True, + transform=transforms.ToTensor(), + target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), + ) + self._train_data, self._val_data = self._split_train_val_data(train_data) + self._test_data = torchvision.datasets.MNIST( + self._data_path, + train=False, + transform=transforms.ToTensor(), + target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), + ) def data_module_fn(data_path: str, seed: int) -> RenateDataModule: From 663ca4d48865181d553ce217d2068f78ac8bbd9c Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 15 May 2023 18:59:44 +0200 Subject: [PATCH 08/89] Make Wild Time Available in Benchmarking (#187) --- doc/benchmarking/renate_benchmarks.rst | 21 +++++ doc/requirements.txt | 1 + examples/nlp_finetuning/renate_config.py | 4 +- pyproject.toml | 4 + src/renate/benchmark/datasets/nlp_datasets.py | 25 ++--- .../benchmark/datasets/wild_time_data.py | 93 +++++++++++++++++++ src/renate/benchmark/experiment_config.py | 49 ++++++++-- src/renate/benchmark/experimentation.py | 5 +- src/renate/benchmark/models/resnet.py | 12 +++ src/renate/benchmark/scenarios.py | 38 ++++++++ src/renate/defaults.py | 2 +- src/renate/training/training.py | 17 +++- test/renate/benchmark/models/test_resnet.py | 8 ++ .../benchmark/test_experimentation_config.py | 59 +++++++++++- test/renate/data/test_data_module.py | 7 +- 15 files changed, 312 insertions(+), 33 deletions(-) create mode 100644 src/renate/benchmark/datasets/wild_time_data.py diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index 3f128736..402ad19a 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -121,6 +121,10 @@ The following table contains the list of supported datasets. - Task - Data Summary - Reference + * - arxiv + - Text Classification: category recognition of arXiv papers. + - ~1.9M train, ~206k test, 172 classes, years 2007-2023 + - Huaxiu Yao et al.: Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time. Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2022. * - CIFAR10 - Image Classification - 50k train, 10k test, 10 classes, image shape 32x32x3 @@ -133,10 +137,22 @@ The following table contains the list of supported datasets. - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 - Han Xiao et al.: Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms. 2017. + * - fmow + - Image Classification: land use recognition from satellite images. + - 62 classes, image shape 32x32x3 + - Huaxiu Yao et al.: Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time. Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2022. + * - huffpost + - Text Classification: category recognition of news paper articles. 
+ - ~58k train, ~6k test, 11 classes, years 2012-2019 + - Huaxiu Yao et al.: Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time. Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2022. * - MNIST - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 - Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012. + * - yearbook + - Image Classification: gender identification in yearbook photos. + - ~33k train, ~4k test, 2 classes, years 1930-2013, image shape 32x32x1 + - Huaxiu Yao et al.: Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time. Conference on Neural Information Processing Systems Datasets and Benchmarks Track. 2022. * - hfd-{dataset_name} - multiple - Any `Hugging Face dataset `__ can be used. Just prepend the prefix ``hfd-``, e.g., ``hfd-rotten_tomatoes``. Select input and target columns via ``config_space``, e.g., add ``"input_column": "text", "target_column": "label"`` for the `rotten_tomatoes `__ example. @@ -171,6 +187,11 @@ The first part contains all instances with classes 1 and 2, the second with clas * - :py:class:`~renate.benchmark.scenarios.BenchmarkScenario` - Used in combination only with CLEAR-10 or CLEAR-100. - * :code:`num_tasks`: Number of data partitions. + * - :py:class:`~renate.benchmark.scenarios.WildTimeScenario` + - Used in combination only with Wild-Time datasets. This is not the scenario used in the paper. + Data is presented time step by time step and the model is evaluated on test data up to the + current time step. + - * :code:`num_tasks`: Number of data partitions. * - :py:class:`~renate.benchmark.scenarios.ClassIncrementalScenario` - Creates data partitions by splitting the data according to class labels. - * :code:`class_groupings`: Tuple of tuples containing the class labels, e.g., ``((1, ), (2, 3, 4))``. 
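
A minimal, hypothetical sketch of how the Wild-Time support documented above could be selected in a
benchmark config_space. The key names mirror the model_fn/data_module_fn parameters introduced in
this patch; the concrete values and the omitted call that launches the experiment are assumptions,
not part of the change:

    # Hypothetical config_space for a Wild-Time benchmark run; values are placeholders.
    config_space = {
        "scenario_name": "WildTimeScenario",  # time-incremental scenario added in this patch
        "dataset_name": "yearbook",           # any name in wild_time_data.list_datasets()
        "model_name": "ResNet18",             # model_fn switches to a single input channel for "yearbook"
        # "num_tasks" may be omitted for Wild-Time data: data_module_fn defaults it to
        # len(wild_time_data.available_time_steps(dataset_name)).
    }
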
diff --git a/doc/requirements.txt b/doc/requirements.txt index 3f18a3ad..6cf5cc43 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -9,3 +9,4 @@ sphinx-paramlinks==0.5.4 # Temporarily added avalanche_lib==0.3.1 +wild-time-data==0.1.0 diff --git a/examples/nlp_finetuning/renate_config.py b/examples/nlp_finetuning/renate_config.py index 9cdcd7ca..69055833 100644 --- a/examples/nlp_finetuning/renate_config.py +++ b/examples/nlp_finetuning/renate_config.py @@ -6,7 +6,7 @@ import transformers import renate.defaults as defaults -from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.data.data_module import RenateDataModule from renate.models import RenateModule from renate.models.renate_module import RenateWrapper @@ -28,7 +28,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> """Returns one of two movie review datasets depending on `chunk_id`.""" tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased") dataset_name = "imdb" if chunk_id else "rotten_tomatoes" - data_module = HuggingfaceTextDataModule( + data_module = HuggingFaceTextDataModule( data_path, dataset_name=dataset_name, tokenizer=tokenizer, diff --git a/pyproject.toml b/pyproject.toml index d8a71946..33cd2f27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,13 @@ avalanche = [ "avalanche_lib==0.3.1", "torch>=1.10.0, <1.12.2", ] +benchmark = [ + "wild-time-data==0.1.0", +] dev = [ "black==23.1.0", "avalanche_lib==0.3.1", + "wild-time-data==0.1.0", "torch>=1.10.0, <1.12.2", # PyTest Dependencies "pytest==7.2.2", diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py index 70761bd7..f2a594e6 100644 --- a/src/renate/benchmark/datasets/nlp_datasets.py +++ b/src/renate/benchmark/datasets/nlp_datasets.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import logging -from typing import Optional +from typing import Any, Dict, Optional import datasets import torch @@ -13,7 +13,7 @@ class _InputTargetWrapper(torch.utils.data.Dataset): - """Make a huggingface dataset comply with the `(input, target)` format.""" + """Make a Hugging Face dataset comply with the `(input, target)` format.""" def __init__(self, dataset, target_column: str = "label"): self._dataset = dataset @@ -28,10 +28,10 @@ def __getitem__(self, idx): return item, target -class HuggingfaceTextDataModule(RenateDataModule): - """Data module wrapping Huggingface text datasets. +class HuggingFaceTextDataModule(RenateDataModule): + """Data module wrapping Hugging Face text datasets. - This is a convenience wrapper to expose a hugginface dataset as a `RenateDataModule`. Datasets + This is a convenience wrapper to expose a Hugging Face dataset as a `RenateDataModule`. Datasets will be pre-tokenized and will return `input, target = dataset[i]`, where `input` is a dictionary with fields `["input_ids", "attention_mask"]`, and `target` is a tensor. @@ -41,14 +41,15 @@ class HuggingfaceTextDataModule(RenateDataModule): Args: data_path: the path to the folder containing the dataset files. + tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/ + for more information on tokenizers. dataset_name: Name of the dataset, see https://huggingface.co/datasets. This is a wrapper for text datasets only. 
input_column: Name of the column containing the input text. target_column: Name of the column containing the target (e.g., class label). - tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/ - for more information on tokenizers. - tokenize_kwargs: Keyword arguments to be passed to the tokenizer. Typical options are - `max_length`, `padding` and `truncation`. See https://huggingface.co/docs/tokenizers/ + tokenizer_kwargs: Keyword arguments passed when calling the tokenizer's ``__call__`` + function. Typical options are `max_length`, `padding` and `truncation`. + See https://huggingface.co/docs/tokenizers/ for more information on tokenizers. If `None` is passed, this defaults to `{"padding": "max_length", max_length: 128, truncation: True}`. val_size: Fraction of the training data to be used for validation. @@ -58,15 +59,15 @@ class HuggingfaceTextDataModule(RenateDataModule): def __init__( self, data_path: str, + tokenizer: transformers.PreTrainedTokenizer, dataset_name: str = "ag_news", input_column: str = "text", target_column: str = "label", - tokenizer: Optional[transformers.PreTrainedTokenizer] = None, - tokenizer_kwargs: Optional[dict] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, val_size: float = defaults.VALIDATION_SIZE, seed: int = defaults.SEED, ): - super(HuggingfaceTextDataModule, self).__init__( + super(HuggingFaceTextDataModule, self).__init__( data_path=data_path, val_size=val_size, seed=seed, diff --git a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py new file mode 100644 index 00000000..d1727c5f --- /dev/null +++ b/src/renate/benchmark/datasets/wild_time_data.py @@ -0,0 +1,93 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Any, Dict, Optional, Union + +from transformers import PreTrainedTokenizer +from wild_time_data import load_dataset +from wild_time_data.core import available_time_steps, dataset_classes + +from renate import defaults +from renate.data.data_module import RenateDataModule +from renate.utils.file import download_folder_from_s3 + + +class WildTimeDataModule(RenateDataModule): + """Data module wrapping around the Wild-Time data. + + Huaxiu Yao, Caroline Choi, Bochuan Cao, Yoonho Lee, Pang Wei Koh, Chelsea Finn: + Wild-Time: A Benchmark of in-the-Wild Distribution Shift over Time. NeurIPS 2022 + + Args: + data_path: the path to the folder containing the dataset files. + dataset_name: Name of the wild time dataset. + src_bucket: the name of the s3 bucket. If not provided, downloads the data from original + source. + src_object_name: the folder path in the s3 bucket. + time_step: Time slice to be loaded. + tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/ + for more information on tokenizers. + tokenizer_kwargs: Keyword arguments passed when calling the tokenizer's ``__call__`` + function. + val_size: Fraction of the training data to be used for validation. + seed: Seed used to fix random number generation. 
+ """ + + def __init__( + self, + data_path: Union[Path, str], + dataset_name: str, + src_bucket: Optional[str] = None, + src_object_name: Optional[str] = None, + time_step: int = 0, + tokenizer: Optional[PreTrainedTokenizer] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__( + data_path=data_path, + src_bucket=src_bucket, + src_object_name=src_object_name, + val_size=val_size, + seed=seed, + ) + self._dataset_name = dataset_name + self.time_step = time_step + self._tokenizer = tokenizer + self._tokenizer_kwargs = tokenizer_kwargs + + def prepare_data(self) -> None: + """Download data. + + If s3 bucket is given, the data is downloaded from s3, otherwise from the original source. + """ + if self._src_bucket is None: + load_dataset( + dataset_name=self._dataset_name, + time_step=available_time_steps(self._dataset_name)[0], + split="train", + data_dir=self._data_path, + ) + else: + dst_dir = Path(self._data_path) / dataset_classes[self._dataset_name].file_name + if not dst_dir.exists(): + download_folder_from_s3( + src_bucket=self._src_bucket, + src_object_name=self._src_object_name, + dst_dir=str(dst_dir), + ) + + def setup(self) -> None: + """Set up train, test and val datasets.""" + kwargs = { + "dataset_name": self._dataset_name, + "time_step": available_time_steps(self._dataset_name)[self.time_step], + "data_dir": self._data_path, + "in_memory": self._dataset_name != "fmow", + } + if self._tokenizer: + kwargs["transform"] = lambda x: self._tokenizer(x, **(self._tokenizer_kwargs or {})) + train_data = load_dataset(split="train", **kwargs) + self._train_data, self._val_data = self._split_train_val_data(train_data) + self._test_data = load_dataset(split="test", **kwargs) diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 24a12b6b..783d43fd 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -3,11 +3,13 @@ from typing import List, Optional, Tuple, Union import torch +import wild_time_data from torchvision.transforms import transforms from transformers import AutoTokenizer -from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule +from renate.benchmark.datasets.wild_time_data import WildTimeDataModule from renate.benchmark.models import ( MultiLayerPerceptron, ResNet18, @@ -33,6 +35,7 @@ ImageRotationScenario, PermutationScenario, Scenario, + WildTimeScenario, ) from renate.data.data_module import RenateDataModule from renate.models import RenateModule @@ -64,6 +67,7 @@ def model_fn( num_outputs: Optional[int] = None, num_hidden_layers: Optional[int] = None, hidden_size: Optional[Tuple[int]] = None, + dataset_name: Optional[str] = None, pretrained_model_name: Optional[str] = None, ) -> RenateModule: """Returns a model instance.""" @@ -81,6 +85,8 @@ def model_fn( "hidden_size": hidden_size, } ) + elif model_name.startswith("ResNet") and dataset_name in ["FashionMNIST", "MNIST", "yearbook"]: + model_kwargs["gray_scale"] = True elif model_name == "HuggingFaceTransformer": if updater == "Avalanche-iCaRL": raise ValueError("Transformers do not support iCaRL.") @@ -97,6 +103,8 @@ def model_fn( def get_data_module( data_path: str, + src_bucket: Optional[str], + src_object_name: Optional[str], 
dataset_name: str, val_size: float, seed: int, @@ -110,9 +118,22 @@ def get_data_module( ) if dataset_name in ["CLEAR10", "CLEAR100"]: return CLEARDataModule(data_path, dataset_name=dataset_name, val_size=val_size, seed=seed) + if dataset_name in wild_time_data.list_datasets(): + data_module_kwargs = { + "data_path": data_path, + "src_bucket": src_bucket, + "src_object_name": src_object_name, + "dataset_name": dataset_name, + "val_size": val_size, + "seed": seed, + } + if pretrained_model_name is not None: + data_module_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(pretrained_model_name) + return WildTimeDataModule(**data_module_kwargs) + if dataset_name.startswith("hfd-"): tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) - return HuggingfaceTextDataModule( + return HuggingFaceTextDataModule( data_path=data_path, dataset_name=dataset_name[4:], input_column=input_column, @@ -203,6 +224,10 @@ def get_scenario( chunk_id=chunk_id, seed=seed, ) + if scenario_name == "WildTimeScenario": + return WildTimeScenario( + data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed + ) raise ValueError(f"Unknown scenario `{scenario_name}`.") @@ -219,12 +244,16 @@ def data_module_fn( input_dim: Optional[Tuple[int]] = None, feature_idx: Optional[int] = None, randomness: Optional[float] = None, + src_bucket: Optional[str] = None, + src_object_name: Optional[str] = None, pretrained_model_name: Optional[str] = None, input_column: Optional[str] = None, target_column: Optional[str] = None, ): data_module = get_data_module( data_path=data_path, + src_bucket=src_bucket, + src_object_name=src_object_name, dataset_name=dataset_name, val_size=val_size, seed=seed, @@ -232,6 +261,8 @@ def data_module_fn( input_column=input_column, target_column=target_column, ) + if dataset_name in wild_time_data.list_datasets() and num_tasks is None: + num_tasks = len(wild_time_data.available_time_steps(dataset_name)) return get_scenario( scenario_name=scenario_name, data_module=data_module, @@ -256,9 +287,12 @@ def _get_normalize_transform(dataset_name): def train_transform(dataset_name: str) -> Optional[transforms.Compose]: """Returns a transform function to be used in the training.""" - if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"): + if dataset_name in [ + "MNIST", + "FashionMNIST", + ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): return None - elif dataset_name in ["CIFAR10", "CIFAR100"]: + if dataset_name in ["CIFAR10", "CIFAR100"]: return transforms.Compose( [ transforms.RandomCrop(32, padding=4), @@ -271,8 +305,11 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]: def test_transform(dataset_name: str) -> Optional[transforms.Normalize]: """Returns a transform function to be used for validation or testing.""" - if dataset_name in ["MNIST", "FashionMNIST"] or dataset_name.startswith("hfd-"): + if dataset_name in [ + "MNIST", + "FashionMNIST", + ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): return None - elif dataset_name in ["CIFAR10", "CIFAR100"]: + if dataset_name in ["CIFAR10", "CIFAR100"]: return _get_normalize_transform(dataset_name) raise ValueError(f"Unknown dataset `{dataset_name}`.") diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 41038fc9..8c25290e 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -377,5 +377,8 @@ def _execute_experiment_job_remotely(experiment_outputs_url: 
str, **job_kwargs: experiment_outputs_url ), f"experiment_outputs_url {experiment_outputs_url} is not on S3." return submit_remote_job( - source_dir=None, experiment_outputs_url=experiment_outputs_url, **job_kwargs + source_dir=None, + experiment_outputs_url=experiment_outputs_url, + optional_dependencies="benchmark", + **job_kwargs, ) diff --git a/src/renate/benchmark/models/resnet.py b/src/renate/benchmark/models/resnet.py index 051e2fb7..3132c022 100644 --- a/src/renate/benchmark/models/resnet.py +++ b/src/renate/benchmark/models/resnet.py @@ -28,6 +28,7 @@ class ResNet(RenateBenchmarkingModule): replace_stride_with_dilation: Whether to replace the stride with a dilation to save memory. norm_layer: What kind of normalization layer to use, following convolutions. cifar_stem: Whether to use a stem for CIFAR-sized images. + gray_scale: Whether input images are gray-scale images, i.e. only 1 color channel. loss: Loss function to be used for training. prediction_strategy: Continual learning strategies may alter the prediction at train or test time. @@ -46,6 +47,7 @@ def __init__( replace_stride_with_dilation: Optional[List[bool]] = None, norm_layer: Type[nn.Module] = nn.BatchNorm2d, cifar_stem: bool = True, + gray_scale: bool = False, loss: nn.Module = nn.CrossEntropyLoss(), prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, @@ -72,6 +74,7 @@ def __init__( "replace_stride_with_dilation": replace_stride_with_dilation, "norm_layer": norm_layer, "cifar_stem": cifar_stem, + "gray_scale": gray_scale, }, loss_fn=loss, prediction_strategy=prediction_strategy, @@ -81,6 +84,15 @@ def __init__( if cifar_stem: self._backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self._backbone.maxpool = nn.Identity() + if gray_scale: + self._backbone.conv1 = nn.Conv2d( + 1, + self._backbone.conv1.out_channels, + kernel_size=self._backbone.conv1.kernel_size, + stride=self._backbone.conv1.stride, + padding=self._backbone.conv1.padding, + bias=self._backbone.conv1.bias is not None, + ) self._backbone.fc = nn.Identity() for m in self.modules(): diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index 14fe9cf2..7e45c061 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -9,6 +9,7 @@ from torchvision.transforms import Lambda, RandomRotation, ToPILImage from renate import defaults +from renate.benchmark.datasets.wild_time_data import WildTimeDataModule from renate.data.data_module import RenateDataModule from renate.data.datasets import _TransformedDataset from renate.utils.pytorch import get_generator, randomly_split_data @@ -378,3 +379,40 @@ def _get_scores(self, dataset: Dataset) -> List[float]: ) scores.append(value[np.argmax(count)]) return scores + + +class WildTimeScenario(Scenario): + """Creating a time-incremental scenario for the Wild-Time datasets. + + In contrast to the original work, data is presented time step by time step (no grouping) and + the test set is all data up to the current time step. + + Args: + data_module: The source RenateDataModule for the the user data. + num_tasks: The total number of expected tasks for experimentation. + chunk_id: The data chunk to load in for the training or validation data. + seed: Seed used to fix random number generation. 
+ """ + + def __init__( + self, + data_module: RenateDataModule, + num_tasks: int, + chunk_id: int, + seed: int = defaults.SEED, + ) -> None: + super().__init__(data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed) + if not isinstance(data_module, WildTimeDataModule): + raise ValueError("This scenario is only compatible with `WildTimeDataModule`.") + + def setup(self) -> None: + """Sets up the scenario.""" + self._data_module.time_step = self._chunk_id + self._data_module.setup() + self._train_data = self._data_module.train_data() + self._val_data = self._data_module.val_data() + self._test_data = [] + for i in range(self._num_tasks): + self._data_module.time_step = i + self._data_module.setup() + self._test_data.append(self._data_module.test_data()) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 87f2541b..54b38680 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -103,7 +103,7 @@ MEMORY_SIZE = 32 # Benchmark datasets/models -TOKENIZER_KWARGS = dict(padding="max_length", max_length=128, truncation=True) +TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True} def scheduler(config_space: Dict[str, Any], mode: str, metric: str): diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 6aabc6f3..cd606fa7 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -205,7 +205,10 @@ def run_training_job( def _prepare_remote_job( - tmp_dir: str, requirements_file: Optional[str], **job_kwargs: Any + tmp_dir: str, + requirements_file: Optional[str], + optional_dependencies: Optional[str] = None, + **job_kwargs: Any, ) -> List[str]: """Prepares a SageMaker job.""" dependencies = list(renate.__path__ + [job_kwargs["config_file"]]) @@ -223,7 +226,12 @@ def _prepare_remote_job( if requirements_file is None: requirements_file = os.path.join(tmp_dir, "requirements.txt") with open(requirements_file, "w") as f: - f.write(f"renate=={renate.__version__}") + f.write( + "Renate{}=={}".format( + "" if optional_dependencies is None else f"[{optional_dependencies}]", + renate.__version__, + ) + ) dependencies.append(requirements_file) return dependencies @@ -610,6 +618,7 @@ def submit_remote_job( instance_count: int, instance_max_time: float, job_name: str, + optional_dependencies: Optional[str] = None, **job_kwargs: Any, ) -> str: """Executes the training job on SageMaker. 
@@ -619,7 +628,9 @@ def submit_remote_job( job_timestamp = defaults.current_timestamp() job_name = f"{job_name}-{job_timestamp}" tmp_dir = tempfile.mkdtemp() - dependencies = _prepare_remote_job(tmp_dir=tmp_dir, **job_kwargs) + dependencies = _prepare_remote_job( + tmp_dir=tmp_dir, optional_dependencies=optional_dependencies, **job_kwargs + ) PyTorch( entry_point=tuning_script, source_dir=None if source_dir is None else str(source_dir), diff --git a/test/renate/benchmark/models/test_resnet.py b/test/renate/benchmark/models/test_resnet.py index aed8e528..13f2e006 100644 --- a/test/renate/benchmark/models/test_resnet.py +++ b/test/renate/benchmark/models/test_resnet.py @@ -3,6 +3,7 @@ import pytest import torch +from renate.benchmark.models import ResNet18 from renate.defaults import TASK_ID @@ -67,3 +68,10 @@ def test_renate_resnet_get_params(sub_class, expected_num_params): for i in range(len(first_task_params) - 2, len(first_task_params)): # -2 because the last two parameters are weight and bias of a task specific linear layer assert not torch.equal(first_task_params[i], second_task_params[i]) + + +@pytest.mark.parametrize("gray_scale", (True, False), ids=("gray scale", "rgb")) +def test_renate_resnet_gray_scale_parameter(gray_scale): + """Tests if gray_scale parameter correctly controls number of input channels.""" + expected_in_channels = 1 if gray_scale else 3 + assert ResNet18(gray_scale=gray_scale).get_backbone().conv1.in_channels == expected_in_channels diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 9b71258e..d07f2af3 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -4,7 +4,7 @@ from torchvision.transforms import Compose, Normalize from renate.benchmark import experiment_config -from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule from renate.benchmark.experiment_config import ( data_module_fn, @@ -22,6 +22,7 @@ IIDScenario, ImageRotationScenario, PermutationScenario, + WildTimeScenario, ) from renate.models.prediction_strategies import ICaRLClassificationStrategy @@ -45,6 +46,27 @@ def test_model_fn(model_name, expected_model_class): assert isinstance(model, expected_model_class) +@pytest.mark.parametrize( + "dataset_name,expected_in_channels", + ( + ("FashionMNIST", 1), + ("MNIST", 1), + ("yearbook", 1), + ("CIFAR10", 3), + ("CIFAR100", 3), + ("fmow", 3), + ), +) +def test_model_fn_automatic_input_channel_detection_resnet(dataset_name, expected_in_channels): + """Tests if ResNet architectures input channels are correctly adapted to the dataset.""" + model = model_fn( + model_state_url=None, + model_name="ResNet18", + dataset_name=dataset_name, + ) + assert model.get_backbone().conv1.in_channels == expected_in_channels + + def test_model_fn_fails_for_unknown_model(): unknown_model_name = "UNKNOWN_MODEL_NAME" with pytest.raises(ValueError, match=f"Unknown model `{unknown_model_name}`"): @@ -58,7 +80,7 @@ def test_model_fn_fails_for_unknown_model(): ("CLEAR10", CLEARDataModule, None, None, None), ( "hfd-rotten-tomatoes", - HuggingfaceTextDataModule, + HuggingFaceTextDataModule, "distilbert-base-uncased", "text", "label", @@ -73,6 +95,8 @@ def test_get_data_module( dataset_name=dataset_name, val_size=0.5, seed=0, + src_bucket=None, + 
src_object_name=None, pretrained_model_name=pretrained_model_name, input_column=input_column, target_column=target_column, @@ -88,6 +112,8 @@ def test_get_data_module_fails_for_unknown_dataset(tmpdir): dataset_name=unknown_dataset_name, val_size=0.5, seed=0, + src_bucket=None, + src_object_name=None, pretrained_model_name=None, input_column=None, target_column=None, @@ -100,6 +126,8 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): dataset_name="MNIST", val_size=0.5, seed=0, + src_bucket=None, + src_object_name=None, pretrained_model_name=None, input_column=None, target_column=None, @@ -166,6 +194,20 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): HueShiftScenario, 3, ), + ( + "WildTimeScenario", + "arxiv", + {"num_tasks": 3, "pretrained_model_name": "distilbert-base-uncased"}, + WildTimeScenario, + 3, + ), + ( + "WildTimeScenario", + "fmow", + {}, + WildTimeScenario, + 16, + ), ), ids=[ "class_incremental", @@ -175,6 +217,8 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "permutation", "feature_sorting", "hue_shift", + "wild_time_text_with_tokenizer", + "wild_time_image_all_tasks", ], ) @pytest.mark.parametrize("val_size", (0, 0.5), ids=["no_val", "val"]) @@ -200,10 +244,15 @@ def test_data_module_fn( if expected_scenario_class == ClassIncrementalScenario: assert scenario._class_groupings == scenario_kwargs["class_groupings"] elif expected_scenario_class == FeatureSortingScenario: - scenario._feature_idx = scenario_kwargs["feature_idx"] - scenario._randomness = scenario_kwargs["randomness"] + assert scenario._feature_idx == scenario_kwargs["feature_idx"] + assert scenario._randomness == scenario_kwargs["randomness"] elif expected_scenario_class == HueShiftScenario: - scenario._randomness = scenario_kwargs["randomness"] + assert scenario._randomness == scenario_kwargs["randomness"] + elif expected_scenario_class == WildTimeScenario: + if "pretrained_model_name" in scenario_kwargs: + assert scenario._data_module._tokenizer is not None + else: + assert scenario._data_module._tokenizer is None assert scenario._num_tasks == expected_num_tasks diff --git a/test/renate/data/test_data_module.py b/test/renate/data/test_data_module.py index 0f83097d..60b8eaa0 100644 --- a/test/renate/data/test_data_module.py +++ b/test/renate/data/test_data_module.py @@ -9,7 +9,7 @@ import transformers from torch.utils.data import Dataset, TensorDataset -from renate.benchmark.datasets.nlp_datasets import HuggingfaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.benchmark.datasets.vision_datasets import ( CLEARDataModule, TinyImageNetDataModule, @@ -153,7 +153,7 @@ def test_tiny_imagenet_data_module(tmpdir): def test_hugging_face_data_module( tmpdir, dataset_name, input_column, target_column, tokenizer, tokenizer_kwargs ): - data_module = HuggingfaceTextDataModule( + data_module = HuggingFaceTextDataModule( data_path=tmpdir, dataset_name=dataset_name, input_column=input_column, @@ -187,11 +187,12 @@ def test_hugging_face_exception_raised_with_wrong_column(tmpdir, column): input_column = "WRONG_COLUMN" elif column == "target": target_column = "WRONG_COLUMN" - data_module = HuggingfaceTextDataModule( + data_module = HuggingFaceTextDataModule( data_path=tmpdir, dataset_name="rotten_tomatoes", input_column=input_column, target_column=target_column, + tokenizer=None, val_size=0.2, ) if column == "input": From 8af30dff7e931ee3b4ff470d12d1df1d4d34d90b Mon Sep 17 00:00:00 2001 From: wistuba Date: Tue, 16 May 2023 10:29:17 +0200 
Subject: [PATCH 09/89] Fix `target_column` bug in `HuggingFaceTextDataModule` (#233) --- src/renate/benchmark/datasets/nlp_datasets.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py index f2a594e6..906cdaae 100644 --- a/src/renate/benchmark/datasets/nlp_datasets.py +++ b/src/renate/benchmark/datasets/nlp_datasets.py @@ -6,7 +6,6 @@ import datasets import torch import transformers -from datasets import get_dataset_infos from renate import defaults from renate.data.data_module import RenateDataModule @@ -85,7 +84,10 @@ def prepare_data(self) -> None: raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'train' split.") if "test" not in split_names: raise RuntimeError(f"Dataset {self._dataset_name} does not contain a 'test' split.") - available_columns = list(get_dataset_infos(self._dataset_name)["default"].features) + self._train_data = datasets.load_dataset( + self._dataset_name, split="train", cache_dir=self._data_path + ) + available_columns = list(self._train_data.features) if self._input_column not in available_columns: raise ValueError( f"Input column '{self._input_column}' does not exist in {self._dataset_name}. " @@ -96,9 +98,6 @@ def prepare_data(self) -> None: f"Target column '{self._target_column}' does not exist in {self._dataset_name}. " f"Available columns: {available_columns}." ) - self._train_data = datasets.load_dataset( - self._dataset_name, split="train", cache_dir=self._data_path - ) self._test_data = datasets.load_dataset( self._dataset_name, split="test", cache_dir=self._data_path ) @@ -125,13 +124,13 @@ def tokenize_fn(batch): self._train_data = self._train_data.map(tokenize_fn, batched=True) self._train_data.set_format(type="torch", columns=columns) - self._train_data = _InputTargetWrapper(self._train_data) + self._train_data = _InputTargetWrapper(self._train_data, self._target_column) self._test_data = self._test_data.map(tokenize_fn, batched=True) self._test_data.set_format(type="torch", columns=columns) - self._test_data = _InputTargetWrapper(self._test_data) + self._test_data = _InputTargetWrapper(self._test_data, self._target_column) if self._val_data is not None: self._val_data = self._val_data.map(tokenize_fn, batched=True) self._val_data.set_format(type="torch", columns=columns) - self._val_data = _InputTargetWrapper(self._val_data) + self._val_data = _InputTargetWrapper(self._val_data, self._target_column) else: self._train_data, self._val_data = self._split_train_val_data(self._train_data) From b83550a11d95fc58b00f5e6d7a5a3672619b6819 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Tue, 23 May 2023 11:54:08 +0200 Subject: [PATCH 10/89] Add MMD covariate shift detector (#237) --- src/renate/shift/__init__.py | 0 src/renate/shift/detector.py | 112 ++++++++++++++++++++++++ src/renate/shift/kernels.py | 41 +++++++++ src/renate/shift/mmd_detectors.py | 49 +++++++++++ src/renate/shift/mmd_helpers.py | 90 +++++++++++++++++++ test/renate/shift/test_detector.py | 20 +++++ test/renate/shift/test_kernels.py | 48 ++++++++++ test/renate/shift/test_mmd_detectors.py | 29 ++++++ test/renate/shift/test_mmd_helpers.py | 30 +++++++ 9 files changed, 419 insertions(+) create mode 100644 src/renate/shift/__init__.py create mode 100644 src/renate/shift/detector.py create mode 100644 src/renate/shift/kernels.py create mode 100644 src/renate/shift/mmd_detectors.py create mode 100644 src/renate/shift/mmd_helpers.py create mode 100644 
test/renate/shift/test_detector.py create mode 100644 test/renate/shift/test_kernels.py create mode 100644 test/renate/shift/test_mmd_detectors.py create mode 100644 test/renate/shift/test_mmd_helpers.py diff --git a/src/renate/shift/__init__.py b/src/renate/shift/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/renate/shift/detector.py b/src/renate/shift/detector.py new file mode 100644 index 00000000..c40c3839 --- /dev/null +++ b/src/renate/shift/detector.py @@ -0,0 +1,112 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import torch +from torch.utils.data import Dataset, DataLoader + +from typing import Optional + +from renate.utils.pytorch import move_tensors_to_device + + +class ShiftDetector: + """Base class for distribution shift detectors. + + The main interface consists of two methods `fit` and `score`, which expect pytorch Dataset + objects. One passes a reference dataset to the `fit` method. Then we can check query datasets + for distribution shifts (relative to the reference dataset) using the `score` method. The + `score` method returns a scalar shift score with the convention that high values indicate a + distribution shift. For most methods, this score will be in [0, 1]. + + Args: + batch_size: Batch size used to iterate over datasets, e.g., for extracting features. This + choice does not affect the result of the shift detector, but might affect run time. + num_preprocessing_workers: Number of workers used in data loaders. + device: Device to use for computations inside the detector. + """ + + def __init__( + self, + batch_size: int = 32, + num_preprocessing_workers: int = 0, + device: str = "cpu", + ) -> None: + self._batch_size = batch_size + self._num_preprocessing_workers = num_preprocessing_workers + self._device = device + + def fit(self, dataset: Dataset) -> None: + """Fit the detector to a reference dataset.""" + raise NotImplementedError() + + def score(self, dataset: Dataset) -> float: + """Compute distribution shift score for a query dataset.""" + raise NotImplementedError() + + def _make_data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: + """Return a data loader to iterate over a dataset. + + Args: + dataset: The dataset. + shuffle: Whether to shuffle or not. + """ + return DataLoader( + dataset, + batch_size=self._batch_size, + shuffle=shuffle, + num_workers=self._num_preprocessing_workers, + ) + + +class ShiftDetectorWithFeatureExtractor(ShiftDetector): + """Base class for detectors working on extracted features. + + These shift detectors extract some (lower-dimensional) features from the datasets, which are + used as inputs to the shift detection methods. Subclasses have to overwrite `fit_with_features` + and `score_with_features`. + + Args: + feature_extractor: A pytorch model used as feature extractor. + batch_size: Batch size used to iterate over datasets. + num_preprocessing_workers: Number of workers used in data loaders. + device: Device to use for computations inside the detector. 
+ """ + + def __init__( + self, + feature_extractor: Optional[torch.nn.Module] = None, + batch_size: int = 32, + num_preprocessing_workers: int = 0, + device: str = "cpu", + ) -> None: + super(ShiftDetectorWithFeatureExtractor, self).__init__( + batch_size, num_preprocessing_workers, device + ) + self._feature_extractor = feature_extractor or torch.nn.Identity() + self._feature_extractor = self._feature_extractor.to(self._device) + + def fit(self, dataset: Dataset) -> None: + """Fit the detector to a reference dataset.""" + X = self.extract_features(dataset) + self._fit_with_features(X) + + def score(self, dataset: Dataset) -> float: + """Compute distribution shift score for a query dataset.""" + X = self.extract_features(dataset) + return self._score_with_features(X) + + @torch.no_grad() + def extract_features(self, dataset: Dataset) -> torch.Tensor: + """Extract features from a dataset.""" + dataloader = self._make_data_loader(dataset) + Xs = [] + for batch in dataloader: + X = move_tensors_to_device(batch[0], device=self._device) + Xs.append(self._feature_extractor(X)) + X = torch.cat(Xs, dim=0).cpu() + return X + + def _fit_with_features(self, X: torch.Tensor) -> None: + raise NotImplementedError() + + def _score_with_features(self, X: torch.Tensor) -> float: + raise NotImplementedError() diff --git a/src/renate/shift/kernels.py b/src/renate/shift/kernels.py new file mode 100644 index 00000000..5e67846f --- /dev/null +++ b/src/renate/shift/kernels.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + + +class Kernel: + """Base class for kernel functions.""" + + def __init__(self): + pass + + def _check_inputs(self, X0: torch.Tensor, X1: torch.Tensor): + assert X0.dim() == X1.dim() == 2 + assert X0.size(1) == X1.size(1) + assert X0.dtype is X1.dtype + + +class RBFKernel(Kernel): + """A radial basis function kernel. + + This kernel has one hyperparameter, a scalar lengthscale. If this is set to `None` (default), + the lengthscale will be set adaptively, at _each_ call to the kernel, via the median heuristic. + + Args: + lengthscale: The kernel lengthscale. If `None` (default), this is set automatically via the + median heuristic. Note: In this case, the lengthscale will be reset at each call to the + kernel. + """ + + def __init__(self, lengthscale: Optional[float] = None): + super().__init__() + self._lengthscale = lengthscale + + @torch.no_grad() + def __call__(self, X0: torch.Tensor, X1: torch.Tensor): + self._check_inputs(X0, X1) + dists = torch.cdist(X0, X1) + lengthscale = self._lengthscale or torch.median(dists) + return torch.exp(-0.5 * dists**2 / lengthscale**2) diff --git a/src/renate/shift/mmd_detectors.py b/src/renate/shift/mmd_detectors.py new file mode 100644 index 00000000..911b516e --- /dev/null +++ b/src/renate/shift/mmd_detectors.py @@ -0,0 +1,49 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import torch + +from renate.shift.detector import ShiftDetectorWithFeatureExtractor +from renate.shift.kernels import RBFKernel +from renate.shift.mmd_helpers import mmd + + +class MMDCovariateShiftDetector(ShiftDetectorWithFeatureExtractor): + """A kernel maximum mean discrepancy (MMD) test. + + This test was proposed by + + [1] Gretton, A., et al. A kernel two-sample test. JMLR (2012). + + We currently do not expose the choice of kernel. 
It defaults to an RBF kernel with a lengthscale + set via the median heuristic. + + The detector computes an approximate p-value via a permutation test. The `score` method returns + `1 - p_value` to conform to the convention that high scores indicate a shift. + + Args: + feature_extractor: A pytorch model used as feature extractor. + num_permutations: Number of permutations for permutation test. + batch_size: Batch size used to iterate over datasets. + num_preprocessing_workers: Number of workers used in data loaders. + device: Device to use for computations inside the detector. + """ + + def __init__( + self, + feature_extractor: Optional[torch.nn.Module] = None, + num_permutations: int = 1000, + batch_size: int = 32, + num_preprocessing_workers: int = 0, + device: str = "cpu", + ) -> None: + super().__init__(feature_extractor, batch_size, num_preprocessing_workers, device) + self._num_permutations = num_permutations + + def _fit_with_features(self, X: torch.Tensor): + self._X_ref = X + + def _score_with_features(self, X: torch.Tensor) -> float: + _, p_val = mmd(self._X_ref, X, kernel=RBFKernel(), num_permutations=self._num_permutations) + return 1.0 - p_val.item() diff --git a/src/renate/shift/mmd_helpers.py b/src/renate/shift/mmd_helpers.py new file mode 100644 index 00000000..4a3da536 --- /dev/null +++ b/src/renate/shift/mmd_helpers.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + +import torch + +from renate.shift.kernels import Kernel + + +def mmd_gram( + K: torch.Tensor, z: torch.Tensor, num_permutations: int = 0 +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Maximum mean discrepancy based on a precomputed kernel-Gram matrix. + + This computes the test statistic and (optionally) p-value to conduct an MMD two-sample test to + decide whether two sets are generated by the same distribution. The inputs are passed implicitly + in the form of a kernel Gram matrix, evaluated across the union of both sets, and a binary + vector indicating the assignments of data points to the two sets. I.e., a value of `z[i] = 0` + indicates that the `i`-th data point belongs to set zero. Optionally, a permutation test is + carried out and a p-value is returned alongside the raw test statistic. + + MMD tests have been proposed by + + [1] Gretton, A., et al. A kernel two-sample test. JMLR (2012). + + Args: + K: A tensor containing a kernel-Gram matrix (size `(n, n)`). + z: A binary vector of length `n` indicating the partition. + num_permutations: If this is a positive number, a permutation test will be carried out and + an approximate p-value will be returned. + + Returns: + A tuple `(t, p)` of two scalar floats, where `t` is the value of the MMD test statistic and + `p` is the p-value (or `None` if `num_permutations=0`). + """ + n = K.size(0) + assert K.size(1) == n + assert z.size() == (n,) + inds_0 = torch.where(z == 0)[0] + inds_1 = torch.where(z == 1)[0] + n0 = len(inds_0) + n1 = len(inds_1) + + mmd = ( + K[inds_0][:, inds_0].sum() / (n0 * (n0 - 1)) + + K[inds_1][:, inds_1].sum() / (n1 * (n1 - 1)) + - 2 * K[inds_0][:, inds_1].mean() + ) # MMD statistic, see Eq. (5) in [1]. + + if num_permutations == 0: + return mmd, None + + # Permutation test: Randomize the assignments z, compute MMD, and count how often we exceed the + # value obtained with original z. 
+ cnt = 0 + for _ in range(num_permutations): + z_ = torch.zeros(n, device=K.device) + pi = torch.randperm(n, device=K.device) + z_[pi[:n1]] = 1 + mmd0, _ = mmd_gram(K, z_, num_permutations=0) + if mmd0 > mmd: + cnt += 1 + p_val = torch.tensor(cnt / num_permutations) + + return mmd, p_val + + +def mmd( + X0: torch.Tensor, X1: torch.Tensor, kernel: Kernel, num_permutations: int = 0 +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Compute MMD between two samples. + + Optionally, return an estimated p-value based on a permutation test. + + Args: + X0: First sample, shape (n0, dx). + X1: Second sample, shape (n1, dx). + kernel: Kernel function to use. + num_permutations: If this is a positive number, a permutation test will + be carried out and a p-value will be returned. + + Returns: + A tuple `(t, p)` where `t` is the value of the MMD test statistic and `p` is the p-value + (or `None` if `num_permutations=0`). + """ + X = torch.cat([X0, X1], dim=0) + z = torch.cat([torch.zeros(X0.size(0)), torch.ones(X1.size(0))], dim=0) + K = kernel(X, X) + assert not torch.any(torch.isnan(K)) + return mmd_gram(K, z, num_permutations) diff --git a/test/renate/shift/test_detector.py b/test/renate/shift/test_detector.py new file mode 100644 index 00000000..10436aa2 --- /dev/null +++ b/test/renate/shift/test_detector.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.shift.detector import ShiftDetectorWithFeatureExtractor + + +@pytest.mark.parametrize( + "dataset", + [ + torch.utils.data.TensorDataset(torch.randn(20, 2)), + torch.utils.data.TensorDataset(torch.randn(20, 2), torch.randint(0, 10, size=(20,))), + ], +) +@pytest.mark.parametrize("feature_extractor", [torch.nn.Linear(2, 4)]) +def test_extract_features(dataset, feature_extractor): + detector = ShiftDetectorWithFeatureExtractor(feature_extractor=feature_extractor) + features = detector.extract_features(dataset) + assert features.size() == (20, 4) diff --git a/test/renate/shift/test_kernels.py b/test/renate/shift/test_kernels.py new file mode 100644 index 00000000..909e63d6 --- /dev/null +++ b/test/renate/shift/test_kernels.py @@ -0,0 +1,48 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.shift.kernels import RBFKernel + + +@pytest.mark.parametrize("kernel", [RBFKernel()]) +@pytest.mark.parametrize( + "X1, X2", + [ + (torch.randn(20, 2), torch.randn(20, 2)), + (torch.randn(10, 2), torch.randn(20, 2)), + (torch.randn(20, 2), torch.randn(10, 2)), + ], +) +def test_kernel_shapes(kernel, X1, X2): + K = kernel(X1, X2) + assert K.size() == (X1.size(0), X2.size(0)) + + +@pytest.mark.parametrize("kernel", [RBFKernel()]) +def test_kernel_shape_mismatch(kernel): + with pytest.raises(Exception): + kernel(torch.randn(10, 2), torch.randn(20, 3)) + + +def test_rbf_vs_manual_computation(): + X0 = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + X1 = torch.tensor([[1.0, 0.0], [2.0, 2.0]]) + kernel = RBFKernel(lengthscale=1.0) + K = kernel(X0, X1) + K_exp = torch.exp(-0.5 * torch.tensor([[1.0, 8.0], [2.0, 5.0]])) + assert torch.allclose(K, K_exp) + + +def test_rbf_kernel_limits(): + """Tests limit behavior of RBF kernel""" + X = torch.randn(10, 2) + # Small lengthscales should result in vanishing off-diagonal terms. 
+ kernel = RBFKernel(lengthscale=1e-8) + K = kernel(X, X) + assert torch.allclose(K, torch.eye(10)) + # High lengthscales should result in a matrix of all ones. + kernel = RBFKernel(lengthscale=1e8) + K = kernel(X, X) + assert torch.allclose(K, torch.ones(10)) diff --git a/test/renate/shift/test_mmd_detectors.py b/test/renate/shift/test_mmd_detectors.py new file mode 100644 index 00000000..d7dc67d1 --- /dev/null +++ b/test/renate/shift/test_mmd_detectors.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.shift.mmd_detectors import MMDCovariateShiftDetector + + +@pytest.mark.parametrize( + "detector", [MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100)] +) +def test_shift_detector_identical_data(detector): + """We expect low scores for identical data.""" + dataset = torch.utils.data.TensorDataset(torch.randn(100, 2)) + detector.fit(dataset) + score = detector.score(dataset) + assert score == 0.0 + + +@pytest.mark.parametrize( + "detector", [MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100)] +) +def test_shift_detector_disjoint_data(detector): + """We expect high scores for very different data (two disjoint Gaussian blobs).""" + dataset_ref = torch.utils.data.TensorDataset(torch.randn(100, 2)) + dataset_query = torch.utils.data.TensorDataset(torch.randn(100, 2) + 2.0) + detector.fit(dataset_ref) + score = detector.score(dataset_query) + assert score == 1.0 diff --git a/test/renate/shift/test_mmd_helpers.py b/test/renate/shift/test_mmd_helpers.py new file mode 100644 index 00000000..a2b36124 --- /dev/null +++ b/test/renate/shift/test_mmd_helpers.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.shift.kernels import RBFKernel +from renate.shift.mmd_helpers import mmd + + +def test_mmd_identical_data(): + """We expect high p-values for identical data.""" + X = torch.randn(100, 2) + _, p_val = mmd(X, X, kernel=RBFKernel(), num_permutations=100) + assert p_val == 1.0 + + +def test_shift_detector_disjoint_data(): + """We expect low p-values for very different data (two disjoint Gaussian blobs).""" + X0 = torch.randn(100, 2) + X1 = torch.randn(100, 2) + 2.0 + _, p_val = mmd(X0, X1, kernel=RBFKernel(), num_permutations=100) + assert p_val == 0.0 + + +def test_mmd_vs_manual(): + X0 = torch.tensor([[0.0, 0.0], [0.0, 1.0]]) + X1 = torch.tensor([[1.0, 0.0], [2.0, 2.0]]) + mmd_val, _ = mmd(X0, X1, kernel=RBFKernel(lengthscale=1.0), num_permutations=0) + # Compare against manual computation of the terms in the MMD formula. 
+ assert mmd_val.item() == pytest.approx(1.6065 + 1.0821 - 2 * 0.2687, abs=1e-4) From 944a23ca822bc5a49a9d026eaaeceeaf78167489 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Tue, 23 May 2023 14:07:03 +0200 Subject: [PATCH 11/89] Add KS covariate shift detector (#242) --- requirements.txt | 1 + src/renate/shift/ks_detector.py | 32 +++++++++++++++++++ ...test_detector.py => test_detector_base.py} | 0 ...est_mmd_detectors.py => test_detectors.py} | 13 ++++++-- 4 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 src/renate/shift/ks_detector.py rename test/renate/shift/{test_detector.py => test_detector_base.py} (100%) rename test/renate/shift/{test_mmd_detectors.py => test_detectors.py} (66%) diff --git a/requirements.txt b/requirements.txt index ff626d22..65dbff7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ torchmetrics~=0.10.3 torchvision>=0.13.0, <0.15.2 datasets~=2.9.0 transformers>4.23.0, <4.26.2 +scipy>=1.9.0, <1.10.2 diff --git a/src/renate/shift/ks_detector.py b/src/renate/shift/ks_detector.py new file mode 100644 index 00000000..6f1431c8 --- /dev/null +++ b/src/renate/shift/ks_detector.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from scipy.stats import kstest +import torch + +from renate.shift.detector import ShiftDetectorWithFeatureExtractor + + +class KolmogorovSmirnovCovariateShiftDetector(ShiftDetectorWithFeatureExtractor): + """A Kolmogorov-Smirnov (KS) test on each feature. + + A KS test is a univariate two-sample test, which we perform separately for each feature. To + aggregate these tests without running into multiple-testing problems, we use a Bonferroni + correction, as proposed in + + Rabanser et al. Failing Loudly: An Empirical Study of Methods for Detecting Dataset Shift. + NeurIPS 2019. + """ + + def _fit_with_features(self, X: torch.Tensor) -> None: + self._X_ref = X + + def _score_with_features(self, X: torch.Tensor) -> float: + n_features = X.size(1) + p_vals = [ + kstest(X[:, i].numpy(), self._X_ref[:, i].numpy()).pvalue for i in range(n_features) + ] + # Bonferroni correction: Reject only if the minimal p-value among the multiple tests is + # lower than `alpha / num_tests`, where `alpha` is the significance level. Equivalently, we + # multiple the p-value by `num_tests`. 
+ p_val = min(1.0, min(p_vals) * n_features) + return 1.0 - p_val diff --git a/test/renate/shift/test_detector.py b/test/renate/shift/test_detector_base.py similarity index 100% rename from test/renate/shift/test_detector.py rename to test/renate/shift/test_detector_base.py diff --git a/test/renate/shift/test_mmd_detectors.py b/test/renate/shift/test_detectors.py similarity index 66% rename from test/renate/shift/test_mmd_detectors.py rename to test/renate/shift/test_detectors.py index d7dc67d1..82ffd259 100644 --- a/test/renate/shift/test_mmd_detectors.py +++ b/test/renate/shift/test_detectors.py @@ -4,10 +4,15 @@ import torch from renate.shift.mmd_detectors import MMDCovariateShiftDetector +from renate.shift.ks_detector import KolmogorovSmirnovCovariateShiftDetector @pytest.mark.parametrize( - "detector", [MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100)] + "detector", + [ + MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100), + KolmogorovSmirnovCovariateShiftDetector(feature_extractor=None), + ], ) def test_shift_detector_identical_data(detector): """We expect low scores for identical data.""" @@ -18,7 +23,11 @@ def test_shift_detector_identical_data(detector): @pytest.mark.parametrize( - "detector", [MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100)] + "detector", + [ + MMDCovariateShiftDetector(feature_extractor=None, num_permutations=100), + KolmogorovSmirnovCovariateShiftDetector(feature_extractor=None), + ], ) def test_shift_detector_disjoint_data(detector): """We expect high scores for very different data (two disjoint Gaussian blobs).""" From 2ab5d6e42dd1e3d7258519ce72b41989fb9cf8e7 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 24 May 2023 10:04:05 +0200 Subject: [PATCH 12/89] Update dependabot.yml (#248) --- .github/dependabot.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index fca8a326..f9a42aef 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -8,6 +8,7 @@ updates: schedule: interval: "weekly" target-branch: "dev" + open-pull-requests-limit: 20 - package-ecosystem: "github-actions" directory: "/" schedule: From 0198caa88bb567361f0dfd2e9a5e17f286ec0fd6 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 24 May 2023 10:05:14 +0200 Subject: [PATCH 13/89] Update versions of some requirements (#247) --- doc/requirements.txt | 2 +- requirements.txt | 18 +++++++++--------- src/renate/training/training.py | 2 +- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 6cf5cc43..8d7193be 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -3,7 +3,7 @@ Sphinx==6.1.3 sphinx-copybutton==0.5.1 sphinx-hoverxref==1.3.0 sphinxext-opengraph==0.8.1 -pydata-sphinx-theme==0.13.1 +pydata-sphinx-theme==0.13.3 sphinx-autodoc-typehints==1.22.0 sphinx-paramlinks==0.5.4 diff --git a/requirements.txt b/requirements.txt index 65dbff7f..c9ce9bc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,15 +1,15 @@ -numpy>=1.17.2, <1.24.2 +numpy>=1.17.2, <1.24.4 torch>=1.10.0, <1.13.2 -pandas>=1.4.0, <1.5.3 -boto3>=1.26.0, <1.26.116 -requests>=2.28.0, <2.28.2 -sagemaker>=2.112.0, <2.133.1 -syne-tune[aws,gpsearchers]==0.4.1 -pytorch-lightning~=1.8.6 +pandas>=1.4.0, <2.0.2 +boto3>=1.26.0, <1.26.139 +requests>=2.31.0, <2.31.1 +sagemaker>=2.112.0, <2.158.1 +syne-tune[aws,gpsearchers]==0.6.0 +pytorch-lightning>=1.8.0, <1.9.5 Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 torchmetrics~=0.10.3 
torchvision>=0.13.0, <0.15.2 -datasets~=2.9.0 -transformers>4.23.0, <4.26.2 +datasets>=2.9.0, < 2.12.1 +transformers>4.23.0, <4.29.3 scipy>=1.9.0, <1.10.2 diff --git a/src/renate/training/training.py b/src/renate/training/training.py index cd606fa7..329c328e 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -24,9 +24,9 @@ RUSHScheduler, TransferLearningTaskEvaluations, ) +from syne_tune.results_callback import StoreResultsCallback from syne_tune.stopping_criterion import StoppingCriterion from syne_tune.tuner import Tuner -from syne_tune.tuner_callback import StoreResultsCallback from syne_tune.util import experiment_path import renate From f3da9e4895eacbd6d03b8bcb309735e018b4f349 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Wed, 24 May 2023 10:45:49 +0200 Subject: [PATCH 14/89] Add doc page and example for shift detection (#244) --- doc/getting_started/index.rst | 3 +- doc/getting_started/shift_detection.rst | 97 +++++++++++++++++++ .../shift_detection/image_shift_detection.py | 47 +++++++++ src/renate/shift/ks_detector.py | 2 +- 4 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 doc/getting_started/shift_detection.rst create mode 100644 examples/shift_detection/image_shift_detection.py diff --git a/doc/getting_started/index.rst b/doc/getting_started/index.rst index ac344f0a..93e4a890 100644 --- a/doc/getting_started/index.rst +++ b/doc/getting_started/index.rst @@ -7,7 +7,7 @@ of a model. The content is intended to explain the basic steps to be taken when training pipeline based on Renate. If your goal is to test multiple Renate algorithms on your dataset or to test a specific algorithm on -multiple datasets and scenarios, you will be better served by the benchmarking functionalities +multiple datasets and scenarios, you will be better served by the benchmarking functionalities provided in :doc:`../benchmarking/renate_benchmarks`. .. toctree:: @@ -19,3 +19,4 @@ provided in :doc:`../benchmarking/renate_benchmarks`. output supported_algorithms avalanche + shift_detection diff --git a/doc/getting_started/shift_detection.rst b/doc/getting_started/shift_detection.rst new file mode 100644 index 00000000..03264483 --- /dev/null +++ b/doc/getting_started/shift_detection.rst @@ -0,0 +1,97 @@ +Distribution Shift Detection +**************************** + +Retraining or updating of a machine learning model is usually necessitated by *shifts* in the +distribution of data that is being served to the model. +Renate provides methods for distribution shift detection that can help you decide when to update +your model. +This functionality resides in the :py:mod:`renate.shift` subpackage. + +Shift Types +=========== + +In supervised machine learning tasks, one can distinguish different types of shifts in the joint +distribution :math:`p(x, y)`. +A common assumption is that of *covariate shift*, where we assume that :math:`p(x)` changes while +:math:`p(y|x)` stays constant. +In that case, one only needs to inspect :math:`x` data to detect a shift. +Currently, Renate only supports covariate shift detection. + +Shift Detector Interface +======================== + +The shift detectors in :py:mod:`renate.shift` derive from a common class +:py:class:`~renate.shift.detector.ShiftDetector`, which defines the main interface. Once a +:code:`detector` object has been initialized, one calls :code:`detector.fit(dataset_ref)` on a +reference dataset (a PyTorch dataset object). This reference dataset characterizes the expected +data distribution. 
It may, e.g., be the validation set used during the previous fitting of the +model. Subsequently, we can score one or multiple query datasets using the +:code:`detector.score(dataset_query)` method. This method returns a scalar distribution shift score. +We use the convention that high scores indicate a likely distribution shift. For all currently +available models, this score lies between 0 and 1. + +Available Methods +================= + +At the moment, Renate provides two method for covariate shift detection + +* :py:class:`~renate.shift.mmd_detectors.MMDCovariateShiftDetector` uses a multivariate kernel MMD + test. +* :py:class:`~renate.shift.ks_detector.KolmogorovSmirnovCovariateShiftDetector` uses a univariate + Kolmogorov-Smirnov test on each feature, aggregated with a Bonferroni correction. + +Both tests operate on features extracted from the raw data, which is passed using the +:code:`feature_extractor` argument at initialization. The feature extractor is expected to map the +raw input data to informative vectorial representations of moderate dimension. It may be based on +a pretrained model, e.g., by using its penultimate-layer embeddings (see also the example below). + + +Example +======= + +The following example illustrates how to apply the MMD covariate shift detector. +We will work with the CIFAR-10 dataset, which we can conveniently load using Renate's +:py:class:`~renate.benchmark.datasets.vision_datasets.TorchVisionDataModule`. +In practice, you would ingest your own data here, see the documentation for +:py:class:`~renate.data.data_module.RenateDataModule`. + +.. literalinclude:: ../../examples/shift_detection/image_shift_detection.py + :lines: 13-16 + +For the purpose of this demonstration, we now generate a reference dataset as well as two query +datasets: one from the same distribution, and one where we simulate a distribution shift by +blurring images. +In practice, the reference dataset should represent your expected data distribution. +It could, e.g., be the validation set you used during the previous training of your model. +The query dataset would be the data you want to check for distribution shift, e.g., data collected +during the deployment of your model. + +.. literalinclude:: ../../examples/shift_detection/image_shift_detection.py + :lines: 22-26 + +Shift detection methods rely on informative (and relatively low-dimensional) features. +Here, we use a pretrained ResNet model and chop of its output layer. +This leads to 512-dimensional vectorial features. + +.. literalinclude:: ../../examples/shift_detection/image_shift_detection.py + :lines: 31-33 + +You can use any :py:class:`torch.nn.Module`, which may be a pretrained model or use a custom model +that has been trained on the data at hand. +Generally, we have observed very good result when using generic pre-trained models such as ResNets +for image data or BERT models for text. + +Now we can instantiate an MMD-based shift detector. We first fit it to our reference datasets and +then score both the in-distribution query dataset as well as the out-of-distribution query dataset. + +.. literalinclude:: ../../examples/shift_detection/image_shift_detection.py + :lines: 39-47 + +In this toy example, the shift is quite obvious and we will see a very high score for the +out-of-distribution data:: + + Fitting detector... + Scoring in-distribution data... + score = 0.5410000085830688 + Scoring out-of-distribution data... 
+ score = 1.0 diff --git a/examples/shift_detection/image_shift_detection.py b/examples/shift_detection/image_shift_detection.py new file mode 100644 index 00000000..f0b069fb --- /dev/null +++ b/examples/shift_detection/image_shift_detection.py @@ -0,0 +1,47 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import torch +from torchvision.models import resnet18, ResNet18_Weights +from torchvision.transforms import GaussianBlur + +from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule +from renate.data.datasets import _TransformedDataset +from renate.shift.mmd_detectors import MMDCovariateShiftDetector + + +# Load CIFAR-10 training dataset. +data_module = TorchVisionDataModule(data_path="data", dataset_name="CIFAR10", val_size=0.2) +data_module.prepare_data() +data_module.setup() +dataset = data_module.train_data() + +# We now generate a reference dataset as well as two query datasets: one from the same distribution, +# and one where we simulate a distribution shift by blurring images. In practice, the reference +# dataset should represent your expected data distribution. It could, e.g., be the validation set +# you used during the previous training of your model. +dataset_ref = torch.utils.data.Subset(dataset, list(range(1000))) +dataset_query_in = torch.utils.data.Subset(dataset, list(range(1000, 2000))) +dataset_query_out = torch.utils.data.Subset(dataset, list(range(2000, 3000))) +transform = GaussianBlur(kernel_size=5, sigma=1.0) +dataset_query_out = _TransformedDataset(dataset_query_in, transform) + +# Shift detection methods rely on informative (and relatively low-dimensional) features. Here, we +# use a pretrained ResNet model and chop of its output layer. This leads to 512-dimensional +# vectorial features. +feature_extractor = resnet18(weights=ResNet18_Weights.DEFAULT) +feature_extractor.fc = torch.nn.Identity() +feature_extractor.eval() # Eval mode to use frozen batchnorm stats. + +# Now we can instantiate an MMD-based shift detector. We first fit it to our reference datasets, +# and then score both the in-distribution query dataset and the out-of-distribution query dataset. +# In this toy example, the shift is quite obvious and we will see a very high score for the +# out-of-distribution data. +detector = MMDCovariateShiftDetector(feature_extractor=feature_extractor) +print("Fitting detector...") +detector.fit(dataset_ref) +print("Scoring in-distribution data...") +score_in = detector.score(dataset_query_in) +print(f"score = {score_in}") +print("Scoring out-of-distribution data...") +score_out = detector.score(dataset_query_out) +print(f"score = {score_out}") diff --git a/src/renate/shift/ks_detector.py b/src/renate/shift/ks_detector.py index 6f1431c8..484d94dd 100644 --- a/src/renate/shift/ks_detector.py +++ b/src/renate/shift/ks_detector.py @@ -27,6 +27,6 @@ def _score_with_features(self, X: torch.Tensor) -> float: ] # Bonferroni correction: Reject only if the minimal p-value among the multiple tests is # lower than `alpha / num_tests`, where `alpha` is the significance level. Equivalently, we - # multiple the p-value by `num_tests`. + # multiply the p-value by `num_tests`. 
p_val = min(1.0, min(p_vals) * n_features) return 1.0 - p_val From 5033464c38d6d2bdc73933a1eb6c6c28cf7eab1b Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 24 May 2023 11:28:45 +0200 Subject: [PATCH 15/89] Bump version (#252) --- src/renate/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/renate/__init__.py b/src/renate/__init__.py index eaa31c36..5deff2eb 100644 --- a/src/renate/__init__.py +++ b/src/renate/__init__.py @@ -14,4 +14,4 @@ _renate_logger.addHandler(_handler) _renate_logger.propagate = False -__version__ = "0.2.1" +__version__ = "0.3.0" From 4e335a039bc261b929ac332a68e4a70e858812ee Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 24 May 2023 19:16:01 +0200 Subject: [PATCH 16/89] Merge Main into Dev (#264) * Add NLP Components to Benchmarking (#213) * Robust Integration Tests (#214) * Update Renate Config Example (#226) * Make Wild Time Available in Benchmarking (#187) * Fix `target_column` bug in `HuggingFaceTextDataModule` (#233) * Add MMD covariate shift detector (#237) * Add KS covariate shift detector (#242) * Update dependabot.yml (#248) * Update versions of some requirements (#247) * Add doc page and example for shift detection (#244) * Bump version (#252) --------- Co-authored-by: Lukas Balles From 19a22719b84b736985dd1bea50717de8edb9f45e Mon Sep 17 00:00:00 2001 From: Prabhu Teja S Date: Thu, 25 May 2023 16:01:55 +0200 Subject: [PATCH 17/89] MultiGPU training + changes to Checkpointing logic (#218) Signed-off-by: Prabhu Teja S Co-authored-by: Prabhu Teja --- doc/examples/nlp_finetuning.rst | 35 ++++ doc/getting_started/how_to_renate_config.rst | 23 ++- examples/getting_started/renate_config.py | 10 +- examples/nlp_finetuning/renate_config.py | 6 +- examples/nlp_finetuning/start.py | 3 + requirements.txt | 1 + .../benchmark/datasets/vision_datasets.py | 8 +- src/renate/benchmark/experiment_config.py | 4 + src/renate/benchmark/experimentation.py | 19 +- src/renate/benchmark/models/base.py | 30 +-- src/renate/benchmark/models/mlp.py | 3 - src/renate/benchmark/models/resnet.py | 3 - src/renate/benchmark/models/transformer.py | 4 - .../benchmark/models/vision_transformer.py | 3 - src/renate/cli/parsing_functions.py | 16 ++ src/renate/cli/run_training.py | 16 +- src/renate/defaults.py | 2 + src/renate/evaluation/evaluator.py | 6 + src/renate/models/renate_module.py | 38 ++-- src/renate/training/training.py | 16 +- src/renate/updaters/avalanche/learner.py | 4 +- .../updaters/avalanche/model_updater.py | 24 +++ src/renate/updaters/experimental/er.py | 171 +++++------------- .../updaters/experimental/fine_tuning.py | 7 + src/renate/updaters/experimental/gdumb.py | 19 +- src/renate/updaters/experimental/joint.py | 36 ++-- .../updaters/experimental/offline_er.py | 27 +-- .../updaters/experimental/repeated_distill.py | 8 + src/renate/updaters/learner.py | 155 ++++------------ src/renate/updaters/model_updater.py | 91 ++++++++-- src/renate/utils/deepspeed.py | 112 ++++++++++++ src/renate/utils/distributed_strategies.py | 58 ++++++ src/renate/utils/file.py | 15 ++ src/renate/utils/misc.py | 14 ++ src/renate/utils/module.py | 10 + test/conftest.py | 50 ++++- test/integration_tests/run_experiment.py | 4 +- test/renate/benchmark/test_experimentation.py | 4 +- test/renate/models/test_renate_module.py | 20 +- test/renate/renate_config_files/config.py | 4 + .../renate_config_files/config_scenario.py | 4 + test/renate/training/test_run_training.py | 2 +- .../avalanche/test_avalanche_learner.py | 2 +- .../avalanche/test_avalanche_model_updater.py | 2 + 
test/renate/updaters/experimental/test_er.py | 30 ++- .../updaters/experimental/test_fine_tuning.py | 14 +- .../updaters/experimental/test_joint.py | 4 +- .../experimental/test_repeated_distill.py | 32 ++-- test/renate/updaters/test_learner.py | 57 ++---- test/renate/updaters/test_model_updater.py | 24 ++- test/renate/utils/test_deepspeed.py | 26 +++ .../utils/test_distributed_strategies.py | 22 +++ test/renate/utils/test_misc.py | 19 ++ 53 files changed, 869 insertions(+), 448 deletions(-) create mode 100644 src/renate/utils/deepspeed.py create mode 100644 src/renate/utils/distributed_strategies.py create mode 100644 src/renate/utils/misc.py create mode 100644 test/renate/utils/test_deepspeed.py create mode 100644 test/renate/utils/test_distributed_strategies.py create mode 100644 test/renate/utils/test_misc.py diff --git a/doc/examples/nlp_finetuning.rst b/doc/examples/nlp_finetuning.rst index de86b324..311d409e 100644 --- a/doc/examples/nlp_finetuning.rst +++ b/doc/examples/nlp_finetuning.rst @@ -20,6 +20,9 @@ data module expects the name of a dataset as well as a tokenizer. Here, we load dataset in the first training stage (:code:`chunk_id = 0`) and the :code:`"rotten_tomatoes"` dataset for the subsequent model update (:code:`chunk_id = 1`). +The function :code:`loss_fn` defines the appropriate loss criterion. As this is a classification +problem we use :code:`torch.nn.CrossEntropyLoss`. + The data module will return pre-tokenized data and no further transforms are needed in this case. .. literalinclude:: ../../examples/nlp_finetuning/renate_config.py @@ -35,4 +38,36 @@ on this see previous examples or :doc:`../getting_started/how_to_run_training`. :lines: 3- +Support for training large models +--------------------------------- + +To support training methods for larger models, we expose two arguments in the +:code:`run_experiment_job` to enable training on multiple GPUs. For this we exploit the +strategy functionality provided by `Lightning` +`large model tutorial `_ and +`documentation `_. Currently, we +support +the strategies: + +* `"ddp_find_unused_parameters_false"` +* `"ddp"` +* `"deepspeed"` +* `"deepspeed_stage_1"` +* `"deepspeed_stage_2"` +* `"deepspeed_stage_2_offload"` +* `"deepspeed_stage_3"` +* `"deepspeed_stage_3_offload"` +* `"deepspeed_stage_3_offload_nvme"` + +These can be enabled by passing one of the above options to :code:`strategy`. The number of devices +to be used for parallel training can be specified using :code:`devices` argument which defaults to +`1`. We also support lower precision training by passing the :code:`precision` argument which +accepts the options `"16"`, `"32"`, `"64"`, `"bf16"`. Note that it has to be a string and not the +integer `32`. `bf16` is restricted to newer hardware and thus need slightly more attention before +using it. + +See last four lines in the previous code example. + +.. literalinclude:: ../../examples/nlp_finetuning/start.py + :lines: 47-49 diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index a7a08671..f7fc7929 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -49,6 +49,23 @@ method, but simply reinstantiate your model and call :code:`load_state_dict`. return model +Loss Definition +================ + +This function returns a :code:`torch.nn.Module` object that computes the loss with the +signature + +.. 
code-block:: python + + def loss_fn() -> torch.nn.Module: + +An example of this for the task of MNIST classfication above as + +.. literalinclude:: ../../examples/getting_started/renate_config.py + :caption: Loss function example + :lines: 95-96 + + Data Preparation ================ @@ -67,7 +84,7 @@ such as data subsampling or splitting. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 43-66 + :lines: 41-68 Transforms ========== @@ -112,7 +129,7 @@ These are optional as well but, if omitted, Renate will use :code:`train_transfo .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 73-90 + :lines: 71-78 Custom Metrics ============== @@ -124,7 +141,7 @@ or created ad-hoc by implementing the same interface .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 93- + :lines: 91- To enable the usage of additional metrics in Renate it is sufficient to implement the :code:`metrics_fn` function, returning a dictionary where the key is a string containing the diff --git a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index ddc2c8e8..1e931819 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -13,11 +13,9 @@ class MyMNISTMLP(RenateModule): def __init__(self, num_hidden: int) -> None: - # Model hyperparameters as well as the loss function need to registered via RenateModule's + # Model hyperparameters need to registered via RenateModule's # constructor, see documentation. Otherwise, this is a standard torch model. - super().__init__( - constructor_arguments={"num_hidden": num_hidden}, loss_fn=torch.nn.CrossEntropyLoss() - ) + super().__init__(constructor_arguments={"num_hidden": num_hidden}) self._fc1 = torch.nn.Linear(28 * 28, num_hidden) self._fc2 = torch.nn.Linear(num_hidden, 10) @@ -92,3 +90,7 @@ def buffer_transform() -> Callable: def metrics_fn() -> Dict: return {"my_accuracy": Accuracy()} + + +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() diff --git a/examples/nlp_finetuning/renate_config.py b/examples/nlp_finetuning/renate_config.py index 69055833..bcb40a3b 100644 --- a/examples/nlp_finetuning/renate_config.py +++ b/examples/nlp_finetuning/renate_config.py @@ -17,13 +17,17 @@ def model_fn(model_state_url: Optional[str] = None) -> RenateModule: transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained( "distilbert-base-uncased", num_labels=2, return_dict=False ) - model = RenateWrapper(transformer_model, loss_fn=torch.nn.CrossEntropyLoss()) + model = RenateWrapper(transformer_model) if model_state_url is not None: state_dict = torch.load(model_state_url) model.load_state_dict(state_dict) return model +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() + + def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule: """Returns one of two movie review datasets depending on `chunk_id`.""" tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased") diff --git a/examples/nlp_finetuning/start.py b/examples/nlp_finetuning/start.py index 7060ca42..40ae268f 100644 --- a/examples/nlp_finetuning/start.py +++ b/examples/nlp_finetuning/start.py @@ -44,4 +44,7 @@ instance_count=1, instance_type="ml.g4dn.xlarge", job_name="renate-training-nlp-finetuning", + devices=1, + strategy="deepspeed_stage_2", + precision="32", ) diff --git a/requirements.txt 
b/requirements.txt index c9ce9bc9..6da017dc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 torchmetrics~=0.10.3 torchvision>=0.13.0, <0.15.2 +deepspeed==0.9.1 datasets>=2.9.0, < 2.12.1 transformers>4.23.0, <4.29.3 scipy>=1.9.0, <1.10.2 diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 76693241..55acf107 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -172,17 +172,21 @@ def setup(self) -> None: self._data_path, train=True, transform=transforms.ToTensor(), - target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), + target_transform=transforms.Lambda(to_long), ) self._train_data, self._val_data = self._split_train_val_data(train_data) self._test_data = cls( self._data_path, train=False, transform=transforms.ToTensor(), - target_transform=transforms.Lambda(lambda x: torch.tensor(x, dtype=torch.long)), + target_transform=transforms.Lambda(to_long), ) +def to_long(x): + return torch.tensor(x, dtype=torch.long) + + class CLEARDataModule(RenateDataModule): """Datamodule that process CLEAR datasets: CLEAR10 and CLEAR100. diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 783d43fd..1dbb5b11 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -231,6 +231,10 @@ def get_scenario( raise ValueError(f"Unknown scenario `{scenario_name}`.") +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() + + def data_module_fn( data_path: str, chunk_id: int, diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 8c25290e..77bd0144 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -147,6 +147,8 @@ def execute_experiment_job( devices: int = defaults.DEVICES, deterministic_trainer: bool = True, job_name: str = defaults.JOB_NAME, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, ) -> None: """Executes the experiment job. @@ -202,6 +204,8 @@ def execute_experiment_job( devices=devices, deterministic_trainer=deterministic_trainer, seed=seed, + strategy=strategy, + precision=precision, ) _execute_experiment_job_remotely( job_name=job_name, @@ -226,6 +230,8 @@ def execute_experiment_job( instance_type=instance_type, instance_count=instance_count, instance_max_time=instance_max_time, + strategy=strategy, + precision=precision, ) @@ -246,6 +252,8 @@ def _execute_experiment_job_locally( max_num_trials_finished: int, n_workers: int, deterministic_trainer: bool, + strategy: str, + precision: str, ) -> None: """Runs an experiment, combining hyperparameter tuning and model for multiple updates. @@ -291,6 +299,9 @@ def _execute_experiment_job_locally( model_url, ) + # TODO: evaluate's trainer has to use devices=1: + # See https://github.com/Lightning-AI/lightning/issues/2537 + # The fix is to launch evaluation in a seperate process like training. 
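+    # Until then, the evaluation calls below are passed `devices=1` explicitly, while model
+    # updates continue to use the user-provided `devices`.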
results: Dict[str, List[List[float]]] = {} evaluate_and_record_results( results, @@ -301,7 +312,9 @@ def _execute_experiment_job_locally( logged_metrics=metrics, metric_postfix="_init", accelerator=accelerator, - devices=devices, + devices=1, + strategy=strategy, + precision=precision, ) for update_id in range(num_updates): @@ -329,6 +342,8 @@ def _execute_experiment_job_locally( seed=seed, accelerator=accelerator, devices=devices, + precision=precision, + strategy=strategy, deterministic_trainer=deterministic_trainer, ) move_to_uri(output_state_url, input_state_url) @@ -347,7 +362,7 @@ def _execute_experiment_job_locally( target_transform=transforms.get("target_test_transform"), logged_metrics=metrics, accelerator=accelerator, - devices=devices, + devices=1, ) df = individual_metrics_summary(results, update_id + 1, num_updates) save_pandas_df_to_csv( diff --git a/src/renate/benchmark/models/base.py b/src/renate/benchmark/models/base.py index b406795d..30a5424e 100644 --- a/src/renate/benchmark/models/base.py +++ b/src/renate/benchmark/models/base.py @@ -9,6 +9,7 @@ from renate import defaults from renate.models import RenateModule from renate.models.prediction_strategies import ICaRLClassificationStrategy, PredictionStrategy +from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor # TODO: merge unit tests for the submodules @@ -23,7 +24,6 @@ class RenateBenchmarkingModule(RenateModule, ABC): embedding_size: Representation size of the model after the backbone. num_outputs: The number of outputs of the model. constructor_arguments: Arguments needed to instantiate the model. - loss_fn: The loss function to be optimized during the training. prediction_strategy: By default a forward pass through the model. Some ModelUpdater must be combined with specific prediction strategies to work as intended. add_icarl_class_means: Specific parameters for iCaRL. Can be set to ``False`` if any other @@ -35,7 +35,6 @@ def __init__( embedding_size: int, num_outputs: int, constructor_arguments: dict, - loss_fn: torch.nn.Module, prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, ): @@ -43,7 +42,6 @@ def __init__( constructor_arguments["add_icarl_class_means"] = add_icarl_class_means super().__init__( constructor_arguments=constructor_arguments, - loss_fn=loss_fn, ) self._embedding_size = embedding_size self._num_outputs = num_outputs @@ -83,13 +81,23 @@ def get_params(self, task_id: str = defaults.TASK_ID) -> List[torch.nn.Parameter self.get_predictor(task_id=task_id).parameters() ) - def get_extra_state(self) -> Any: - """Get the constructor_arguments, loss and task ids necessary to reconstruct the model.""" - extra_state = super().get_extra_state() + def get_extra_state(self, encode=True) -> Any: + """Get the constructor_arguments and task ids necessary to reconstruct the model. + + Encode converts the state into a torch tensor so that Deepspeed serialization works. + We don't encode any of the super() calls, but encode only the final dict. + """ + extra_state = super().get_extra_state(encode=not encode) extra_state["prediction_strategy"] = self._prediction_strategy - return extra_state + return convert_to_tensor(extra_state) if encode else extra_state + + def set_extra_state(self, state: Any, decode=True): + """Extract the content of the ``_extra_state`` and set the related values in the module. 
- def set_extra_state(self, state: Any): - """Extract the content of the ``_extra_state`` and set the related values in the module.""" - super().set_extra_state(state) - self._prediction_strategy = state["prediction_strategy"] + decode flag is to decode the tensor of pkl bytes.""" + super().set_extra_state(state, decode=decode) + self._prediction_strategy = ( + recover_object_from_tensor(state)["prediction_strategy"] + if decode + else state["prediction_strategy"] + ) diff --git a/src/renate/benchmark/models/mlp.py b/src/renate/benchmark/models/mlp.py index bb38108f..6a0ae43f 100644 --- a/src/renate/benchmark/models/mlp.py +++ b/src/renate/benchmark/models/mlp.py @@ -18,7 +18,6 @@ class MultiLayerPerceptron(RenateBenchmarkingModule): num_hidden_layers: Number of hidden layers. hidden_size: Uniform hidden size or the list or tuple of hidden sizes for individual hidden layers. - loss: Loss function to be used for training. activation: Activation name, matching activation name in `torch.nn` to be used between the hidden layers. batch_normalization: Whether to use Batch Normalization after the activation. By default the @@ -35,7 +34,6 @@ def __init__( num_outputs: int, num_hidden_layers: int, hidden_size: Union[int, List[int], Tuple[int]], - loss: nn.Module = nn.CrossEntropyLoss(), activation: str = "ReLU", batch_normalization: bool = False, prediction_strategy: Optional[PredictionStrategy] = None, @@ -52,7 +50,6 @@ def __init__( "activation": activation, "batch_normalization": batch_normalization, }, - loss_fn=loss, prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) diff --git a/src/renate/benchmark/models/resnet.py b/src/renate/benchmark/models/resnet.py index 3132c022..c56b1a84 100644 --- a/src/renate/benchmark/models/resnet.py +++ b/src/renate/benchmark/models/resnet.py @@ -29,7 +29,6 @@ class ResNet(RenateBenchmarkingModule): norm_layer: What kind of normalization layer to use, following convolutions. cifar_stem: Whether to use a stem for CIFAR-sized images. gray_scale: Whether input images are gray-scale images, i.e. only 1 color channel. - loss: Loss function to be used for training. prediction_strategy: Continual learning strategies may alter the prediction at train or test time. add_icarl_class_means: If ``True``, additional parameters used only by the @@ -48,7 +47,6 @@ def __init__( norm_layer: Type[nn.Module] = nn.BatchNorm2d, cifar_stem: bool = True, gray_scale: bool = False, - loss: nn.Module = nn.CrossEntropyLoss(), prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, ) -> None: @@ -76,7 +74,6 @@ def __init__( "cifar_stem": cifar_stem, "gray_scale": gray_scale, }, - loss_fn=loss, prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) diff --git a/src/renate/benchmark/models/transformer.py b/src/renate/benchmark/models/transformer.py index fe8909d4..71ae33d3 100644 --- a/src/renate/benchmark/models/transformer.py +++ b/src/renate/benchmark/models/transformer.py @@ -3,7 +3,6 @@ from typing import Dict, Optional import torch -import torch.nn as nn from torch import Tensor from transformers import AutoModelForSequenceClassification @@ -16,21 +15,18 @@ class HuggingFaceSequenceClassificationTransformer(RenateModule): Args: pretrained_model_name: Hugging Face model id. num_outputs: Number of outputs. - loss_fn: The loss function to be optimized during the training. 
""" def __init__( self, pretrained_model_name: str, num_outputs: int, - loss_fn: nn.Module = nn.CrossEntropyLoss(), ) -> None: super().__init__( constructor_arguments={ "pretrained_model_name": pretrained_model_name, "num_outputs": num_outputs, }, - loss_fn=loss_fn, ) self._model = AutoModelForSequenceClassification.from_pretrained( pretrained_model_name, num_labels=num_outputs, return_dict=False diff --git a/src/renate/benchmark/models/vision_transformer.py b/src/renate/benchmark/models/vision_transformer.py index 928e90c9..6b16fc53 100644 --- a/src/renate/benchmark/models/vision_transformer.py +++ b/src/renate/benchmark/models/vision_transformer.py @@ -34,7 +34,6 @@ class VisionTransformer(RenateBenchmarkingModule): norm_layer: Normalization layer. conv_stem_configs: List of ConvStemConfig. Each ConvStemConfig corresponds to a convolutional stem. - loss: Loss function. prediction_strategy: Continual learning strategies may alter the prediction at train or test time. add_icarl_class_means: If ``True``, additional parameters used only by the @@ -56,7 +55,6 @@ def __init__( norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), conv_stem_configs: Optional[List[ConvStemConfig]] = None, weights: Optional[WeightsEnum] = None, - loss: nn.Module = nn.CrossEntropyLoss(), prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, ) -> None: @@ -90,7 +88,6 @@ def __init__( "norm_layer": norm_layer, "conv_stem_configs": conv_stem_configs, }, - loss_fn=loss, prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 4a2cabaa..381168f5 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -4,6 +4,7 @@ import ast import inspect import sys +import pytorch_lightning as pl from importlib.util import find_spec from types import ModuleType from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union @@ -286,6 +287,21 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: "help": f"Devices used for this job. Default: {defaults.DEVICES} device.", "argument_group": OPTIONAL_ARGS_GROUP, }, + "strategy": { + "type": str, + "default": defaults.DISTRIBUTED_STRATEGY, + "help": "Distributed training strategy when devices > 1. Default:" + + f"{defaults.DISTRIBUTED_STRATEGY}.", + "argument_group": OPTIONAL_ARGS_GROUP, + "choices": list(pl.strategies.StrategyRegistry.keys()), + }, + "precision": { + "type": str, + "default": defaults.PRECISION, + "help": f"Distributed training precision. 
Default: {defaults.PRECISION}.", + "argument_group": OPTIONAL_ARGS_GROUP, + "choices": ("16", "32", "64", "bf16"), + }, "early_stopping": { "type": str, "default": str(defaults.EARLY_STOPPING), diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index 9ba9484f..04e794e6 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -16,7 +16,13 @@ parse_arguments, ) from renate.utils.file import maybe_download_from_s3, move_to_uri -from renate.utils.module import get_and_setup_data_module, get_metrics, get_model, import_module +from renate.utils.module import ( + get_and_setup_data_module, + get_loss_fn, + get_metrics, + get_model, + import_module, +) from renate.utils.syne_tune import redirect_to_tmp logger = logging.getLogger(__name__) @@ -95,6 +101,7 @@ def run(self): "buffer_transform", "metrics_fn", "scheduler_fn", + "loss_fn", ], ignore_args=["data_path", "model_state_url"], ) @@ -113,6 +120,10 @@ def run(self): model_state_url=self._current_model_file, **get_function_kwargs(args=args, function_args=function_args["model_fn"]), ) + loss_fn = get_loss_fn( + config_module, + **get_function_kwargs(args=args, function_args=function_args["loss_fn"]), + ) metrics = get_metrics(config_module) @@ -128,8 +139,11 @@ def run(self): logged_metrics=metrics, accelerator=args.accelerator, devices=args.devices, + precision=args.precision, + strategy=args.strategy, early_stopping_enabled=args.early_stopping, deterministic_trainer=args.deterministic_trainer, + loss_fn=loss_fn, **learner_kwargs, **get_transforms_dict(config_module, args, function_args), ) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 54b38680..14b3c0fb 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -33,6 +33,8 @@ SUPPORTED_ACCELERATORS_TYPE = Literal["auto", "cpu", "gpu", "tpu"] DEVICES = 1 VOLUME_SIZE = 60 +DISTRIBUTED_STRATEGY = "ddp" +PRECISION = "32" LEARNER = "ER" INSTANCE_COUNT = 1 diff --git a/src/renate/evaluation/evaluator.py b/src/renate/evaluation/evaluator.py index ffd53d4e..743040a1 100644 --- a/src/renate/evaluation/evaluator.py +++ b/src/renate/evaluation/evaluator.py @@ -13,6 +13,8 @@ from renate.data.datasets import _TransformedDataset from renate.evaluation.metrics.utils import create_metrics from renate.models import RenateModule +from renate.utils.distributed_strategies import create_strategy +from renate.utils.misc import int_or_str class Evaluator(LightningModule, abc.ABC): @@ -114,6 +116,8 @@ def evaluate( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, ) -> Dict[str, List[float]]: """Evaluate the model on the test dataset or a set of test datasets corresponding to distinct tasks. 
@@ -159,6 +163,8 @@ def evaluate( logger=logger, enable_checkpointing=False, enable_progress_bar=False, + strategy=create_strategy(devices, strategy), + precision=int_or_str(precision), ) results = {} diff --git a/src/renate/models/renate_module.py b/src/renate/models/renate_module.py index e2c2317d..aa43e344 100644 --- a/src/renate/models/renate_module.py +++ b/src/renate/models/renate_module.py @@ -8,6 +8,7 @@ from renate.models.layers import ContinualNorm from renate.types import NestedTensors +from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor class RenateModule(torch.nn.Module, ABC): @@ -20,7 +21,7 @@ class RenateModule(torch.nn.Module, ABC): in replay-based CL methods. When implementing a subclass of ``RenateModule``, make sure to call the base class' constructor - and provide your model's constructor arguments and loss function. Besides that, you can define a + and provide your model's constructor arguments. Besides that, you can define a ``RenateModule`` just like ``torch.nn.Module``. Example:: @@ -51,13 +52,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Args: constructor_arguments: Arguments needed to instantiate the model. - loss_fn: The loss function to be optimized during the training. """ - def __init__(self, constructor_arguments: dict, loss_fn: torch.nn.Module): + def __init__(self, constructor_arguments: dict): super(RenateModule, self).__init__() self._constructor_arguments = copy.deepcopy(constructor_arguments) - self.loss_fn = loss_fn self._tasks_params_ids: Set[str] = set() self._intermediate_representation_cache: List[torch.Tensor] = [] self._hooks: List[Callable] = [] @@ -70,26 +69,30 @@ def from_state_dict(cls, state_dict): state_dict: The state dict of the model. This method works under the assumption that this has been created by `RenateModule.state_dict()`. """ - constructor_arguments = state_dict["_extra_state"]["constructor_arguments"] + extra_state = recover_object_from_tensor(state_dict["_extra_state"]) + constructor_arguments = extra_state["constructor_arguments"] model = cls(**constructor_arguments) - for task in state_dict["_extra_state"]["tasks_params_ids"]: + for task in extra_state["tasks_params_ids"]: model.add_task_params(task) - model.load_state_dict(state_dict) + # TODO: See https://github.com/awslabs/Renate/issues/236. + # There are changes to the `class_means` or `componenets` of a model + # that are not loaded, and should probably not be stored. 
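+        # `strict=False` makes `load_state_dict` tolerate keys that are present in the checkpoint
+        # but missing from the freshly constructed model (and vice versa) instead of raising.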
+ model.load_state_dict(state_dict, strict=False) return model - def get_extra_state(self) -> Any: - """Get the constructor_arguments, loss and task ids necessary to reconstruct the model.""" - return { + def get_extra_state(self, encode: bool = True) -> Any: + """Get the constructor_arguments, and task ids necessary to reconstruct the model.""" + extra_state = { "constructor_arguments": self._constructor_arguments, "tasks_params_ids": self._tasks_params_ids, - "loss_fn": self.loss_fn, } + return convert_to_tensor(extra_state) if encode else extra_state - def set_extra_state(self, state: Any): + def set_extra_state(self, state: Any, decode: bool = True): """Extract the content of the ``_extra_state`` and set the related values in the module.""" - self._constructor_arguments = state["constructor_arguments"] - self._tasks_params_ids = state["tasks_params_ids"] - self.loss_fn = state["loss_fn"] + extra_state = recover_object_from_tensor(state) if decode else state + self._constructor_arguments = extra_state["constructor_arguments"] + self._tasks_params_ids = extra_state["tasks_params_ids"] @abstractmethod def forward(self, x: NestedTensors, task_id: Optional[str] = None) -> torch.Tensor: @@ -242,11 +245,10 @@ class RenateWrapper(RenateModule): Args: model: The torch model to be wrapped. - loss_fn: The loss function to be optimized during the training. """ - def __init__(self, model: torch.nn.Module, loss_fn: torch.nn.Module) -> None: - super().__init__(constructor_arguments={}, loss_fn=loss_fn) + def __init__(self, model: torch.nn.Module) -> None: + super().__init__(constructor_arguments={}) self._model = model def forward(self, x: NestedTensors, task_id: Optional[str] = None) -> torch.Tensor: diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 329c328e..3f08a5a4 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -90,6 +90,8 @@ def run_training_job( seed: int = defaults.SEED, accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: int = defaults.DEVICES, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, job_name: str = defaults.JOB_NAME, ) -> Optional[Tuner]: @@ -132,6 +134,8 @@ def run_training_job( seed: Seed used for ensuring reproducibility. accelerator: Type of accelerator to use. devices: Number of devices to use. + strategy: Name of the distributed training strategy to use. + precision: Type of bit precision to use. deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. job_name: Prefix for the name of the SageMaker training job. """ @@ -168,6 +172,8 @@ def run_training_job( seed=seed, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) submit_remote_job( @@ -199,6 +205,8 @@ def run_training_job( seed=seed, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, job_name=job_name, ) @@ -512,6 +520,8 @@ def _execute_training_and_tuning_job_locally( accelerator: str, devices: int, deterministic_trainer: bool, + strategy: str, + precision: str, ): """Executes the training job locally. 
@@ -529,6 +539,8 @@ def _execute_training_and_tuning_job_locally( config_space["seed"] = seed config_space["accelerator"] = accelerator config_space["devices"] = devices + config_space["strategy"] = strategy + config_space["precision"] = precision config_space["deterministic_trainer"] = deterministic_trainer if input_state_url is not None: config_space["input_state_url"] = input_state_url @@ -550,7 +562,9 @@ def _execute_training_and_tuning_job_locally( f"Tuning hyperparameters with respect to {metric} ({mode}) for {max_time} seconds on " f"{n_workers} worker(s)." ) - backend = LocalBackend(entry_point=training_script) + # TODO: After bumping up SyneTune >= 0.6, use the argument `num_gpus_per_trial`. + + backend = LocalBackend(entry_point=training_script, rotate_gpus=False if devices > 1 else True) if scheduler is None or not tune_hyperparameters: if scheduler is not None: warnings.warn( diff --git a/src/renate/updaters/avalanche/learner.py b/src/renate/updaters/avalanche/learner.py index d1c8a6ce..61da8fec 100644 --- a/src/renate/updaters/avalanche/learner.py +++ b/src/renate/updaters/avalanche/learner.py @@ -35,7 +35,7 @@ def update_settings( avalanche_learner.plugins = replace_plugin(plugin, avalanche_learner.plugins) avalanche_learner.model = self._model avalanche_learner.optimizer = optimizer - avalanche_learner._criterion = self._model.loss_fn + avalanche_learner._criterion = self._loss_fn avalanche_learner.train_epochs = max_epochs avalanche_learner.train_mb_size = self._batch_size avalanche_learner.eval_mb_size = self._batch_size @@ -55,7 +55,7 @@ def _create_avalanche_learner( return SupervisedTemplate( model=self._model, optimizer=optimizer, - criterion=self._model.loss_fn, + criterion=self._loss_fn, train_mb_size=self._batch_size, eval_mb_size=self._batch_size, train_epochs=train_epochs, diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index 7c4920cc..7a4f04f3 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -223,6 +223,7 @@ class ExperienceReplayAvalancheModelUpdater(AvalancheModelUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, @@ -248,6 +249,8 @@ def __init__( early_stopping_enabled: bool = False, accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -263,6 +266,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -283,6 +287,8 @@ def __init__( early_stopping_enabled=early_stopping_enabled, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, ) @@ -290,6 +296,7 @@ class ElasticWeightConsolidationModelUpdater(AvalancheModelUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, ewc_lambda: float = defaults.EWC_LAMBDA, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, @@ -314,6 +321,8 @@ def __init__( early_stopping_enabled: bool = False, accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + 
strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -328,6 +337,7 @@ def __init__( "batch_size": batch_size, "ewc_lambda": ewc_lambda, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -348,6 +358,8 @@ def __init__( early_stopping_enabled=early_stopping_enabled, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, ) @@ -355,6 +367,7 @@ class LearningWithoutForgettingModelUpdater(AvalancheModelUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, alpha: float = defaults.LWF_ALPHA, temperature: float = defaults.LWF_TEMPERATURE, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, @@ -381,6 +394,8 @@ def __init__( accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, seed: int = defaults.SEED, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { @@ -395,6 +410,7 @@ def __init__( "alpha": alpha, "temperature": temperature, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -415,6 +431,8 @@ def __init__( early_stopping_enabled=early_stopping_enabled, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, ) @@ -422,6 +440,7 @@ class ICaRLModelUpdater(AvalancheModelUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, @@ -447,6 +466,8 @@ def __init__( early_stopping_enabled: bool = False, accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -462,6 +483,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -482,4 +504,6 @@ def __init__( early_stopping_enabled=early_stopping_enabled, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, ) diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index cd27f458..99375860 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -57,15 +57,12 @@ def __init__( ) -> None: self._components_names = list(components.keys()) super().__init__(**kwargs) + self._memory_loader: Optional[DataLoader] = None self._components = components self._loss_weight = loss_weight self._ema_memory_update_gamma = ema_memory_update_gamma self._use_loss_normalization = bool(loss_normalization) - def _post_init(self) -> None: - super()._post_init() - self._memory_loader: Optional[DataLoader] = None - def _create_metrics_collections( self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None ) -> None: @@ -107,40 +104,6 @@ def on_train_start(self) -> None: for component in self._components.values(): component.on_train_start(model=self._model) - def state_dict(self, **kwargs) -> Dict[str, Any]: - """Returns the state of the learner.""" - state_dict = super().state_dict(**kwargs) - state_dict.update( - { - "loss_weight": self._loss_weight, 
- "ema_memory_update_gamma": self._ema_memory_update_gamma, - "loss_normalization": self._use_loss_normalization, - "components": self._components.state_dict(), - "components_names": self._components_names, - } - ) - return state_dict - - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner.""" - self._components_names = state_dict["components_names"] - super().load_state_dict(model, state_dict, **kwargs) - self._loss_weight = state_dict["loss_weight"] - self._ema_memory_update_gamma = state_dict["ema_memory_update_gamma"] - self._use_loss_normalization = state_dict["loss_normalization"] - self._components = self.components(model=model) - self._components.load_state_dict(state_dict["components"]) - - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - """Update the hyperparameters of the learner.""" - super().update_hyperparameters(args) - if "loss_weight" in args: - self._loss_weight = args["loss_weight"] - if "ema_memory_update_gamma" in args: - self._ema_memory_update_gamma = args["ema_memory_update_gamma"] - if "loss_normalization" in args: - self._use_loss_normalization = args["loss_normalization"] - def training_step( self, batch: Tuple[torch.Tensor, Tuple[NestedTensors, torch.Tensor]], batch_idx: int ) -> STEP_OUTPUT: @@ -254,25 +217,20 @@ class ExperienceReplayLearner(BaseExperienceReplayLearner): """ def __init__(self, alpha: float = defaults.ER_ALPHA, **kwargs) -> None: - components = self.components(model=kwargs["model"], alpha=alpha) + components = self.components(loss_fn=kwargs["loss_fn"], alpha=alpha) super().__init__(components=components, **kwargs) def components( - self, model: Optional[RenateModule] = None, alpha: float = defaults.ER_ALPHA + self, loss_fn: Optional[torch.nn.Module] = None, alpha: float = defaults.ER_ALPHA ) -> nn.ModuleDict: return nn.ModuleDict( { "memory_loss": WeightedCustomLossComponent( - loss_fn=model.loss_fn, weight=alpha, sample_new_memory_batch=True + loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True ) } ) - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - super().update_hyperparameters(args) - if "alpha" in args: - self._components["memory_loss"].set_weight(args["alpha"]) - class DarkExperienceReplayLearner(ExperienceReplayLearner): """A Learner that implements Dark Experience Replay. 
@@ -291,15 +249,15 @@ def __init__( self, alpha: float = defaults.DER_ALPHA, beta: float = defaults.DER_BETA, **kwargs ) -> None: super().__init__(alpha=beta, **kwargs) - self._components = self.components(model=kwargs["model"], alpha=alpha, beta=beta) + self._components = self.components(loss_fn=kwargs["loss_fn"], alpha=alpha, beta=beta) def components( self, - model: Optional[RenateModule] = None, + loss_fn: Optional[torch.nn.Module] = None, alpha: float = defaults.DER_ALPHA, beta: float = defaults.DER_BETA, ) -> nn.ModuleDict: - components = super().components(model=model, alpha=beta) + components = super().components(loss_fn=loss_fn, alpha=beta) components.update( { "mse_loss": WeightedMeanSquaredErrorLossComponent( @@ -309,13 +267,6 @@ def components( ) return components - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - super().update_hyperparameters(args) - if "alpha" in args: - self._components["mse_loss"].set_weight(args["alpha"]) - if "beta" in args: - self._components["memory_loss"].set_weight(args["beta"]) - class PooledOutputDistillationExperienceReplayLearner(BaseExperienceReplayLearner): """A Learner that implements Pooled Output Distillation. @@ -346,7 +297,6 @@ def __init__( def components( self, - model: Optional[RenateModule] = None, alpha: float = defaults.POD_ALPHA, distillation_type: str = defaults.POD_DISTILLATION_TYPE, normalize: bool = defaults.POD_NORMALIZE, @@ -362,17 +312,6 @@ def components( } ) - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - """Update the hyperparameters of the learner.""" - super().update_hyperparameters(args) - component = self._components["pod_loss"] - if "alpha" in args: - component.set_weight(args["alpha"]) - if "distillation_type" in args: - component.set_distillation_type(args["distillation_type"]) - if "normalize" in args: - component.set_normalize(args["normalize"]) - class CLSExperienceReplayLearner(BaseExperienceReplayLearner): """A learner that implements a Complementary Learning Systems Based Experience Replay. 
@@ -405,6 +344,7 @@ def __init__( ): components = self.components( model=kwargs["model"], + loss_fn=kwargs["loss_fn"], alpha=alpha, beta=beta, stable_model_update_weight=stable_model_update_weight, @@ -412,11 +352,13 @@ def __init__( stable_model_update_probability=stable_model_update_probability, plastic_model_update_probability=plastic_model_update_probability, ) + super().__init__(components=components, **kwargs) def components( self, model: RenateModule, + loss_fn: torch.nn.Module, alpha: float = defaults.CLS_ALPHA, beta: float = defaults.CLS_BETA, plastic_model_update_weight: float = defaults.CLS_PLASTIC_MODEL_UPDATE_WEIGHT, @@ -427,7 +369,7 @@ def components( return nn.ModuleDict( { "memory_loss": WeightedCustomLossComponent( - loss_fn=model.loss_fn, weight=alpha, sample_new_memory_batch=True + loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True ), "cls_loss": WeightedCLSLossComponent( weight=beta, @@ -441,28 +383,6 @@ def components( } ) - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - super().update_hyperparameters(args) - memory_loss_component = self._components["memory_loss"] - if "alpha" in args: - memory_loss_component.set_weight(args["alpha"]) - - cls_component = self._components["cls_loss"] - if "beta" in args: - cls_component.set_weight(args["beta"]) - if "stable_model_update_weight" in args: - cls_component.set_stable_model_update_weight(args["stable_model_update_weight"]) - if "plastic_model_update_weight" in args: - cls_component.set_plastic_model_update_weight(args["plastic_model_update_weight"]) - if "stable_model_update_probability" in args: - cls_component.set_stable_model_update_probability( - args["stable_model_update_probability"] - ) - if "plastic_model_update_probability" in args: - cls_component.set_plastic_model_update_probability( - args["plastic_model_update_probability"] - ) - class SuperExperienceReplayLearner(BaseExperienceReplayLearner): """A learner that implements a selected combination of methods. 
@@ -510,6 +430,7 @@ def __init__( ) -> None: components = self.components( model=kwargs["model"], + loss_fn=kwargs["loss_fn"], der_alpha=der_alpha, der_beta=der_beta, sp_shrink_factor=sp_shrink_factor, @@ -532,6 +453,7 @@ def __init__( def components( self, model: RenateModule, + loss_fn: torch.nn.Module, der_alpha: float = defaults.SER_DER_ALPHA, der_beta: float = defaults.SER_DER_BETA, sp_shrink_factor: float = defaults.SER_SP_SHRINK_FACTOR, @@ -551,7 +473,7 @@ def components( weight=der_alpha, sample_new_memory_batch=True ), "memory_loss": WeightedCustomLossComponent( - loss_fn=model.loss_fn, weight=der_beta, sample_new_memory_batch=True + loss_fn=loss_fn, weight=der_beta, sample_new_memory_batch=True ), "cls_loss": WeightedCLSLossComponent( weight=cls_alpha, @@ -574,46 +496,12 @@ def components( } ) - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - super().update_hyperparameters(args) - if "der_alpha" in args: - self._components["mse_loss"].set_weight(args["der_alpha"]) - if "der_beta" in args: - self._components["memory_loss"].set_weight(args["der_beta"]) - if "sp_mu" in args: - self._components["shrink_perturb"].set_shrink_factor(args["sp_mu"]) - if "sp_sigma" in args: - self._components["shrink_perturb"].set_sigma(args["sp_sigma"]) - if "pod_alpha" in args: - self._components["pod_loss"].set_weight(args["pod_alpha"]) - if "pod_distillation_type" in args: - self._components["pod_loss"].set_distillation_type(args["pod_distillation_type"]) - if "pod_normalize" in args: - self._components["pod_loss"].set_normalize(args["pod_normalize"]) - if "cls_alpha" in args: - self._components["cls_loss"].set_weight(args["cls_alpha"]) - if "cls_stable_model_update_weight" in args: - self._components["cls_loss"].set_stable_model_update_weight( - args["cls_stable_model_update_weight"] - ) - if "cls_plastic_model_update_weight" in args: - self._components["cls_loss"].set_plastic_model_update_weight( - args["cls_plastic_model_update_weight"] - ) - if "cls_stable_model_update_probability" in args: - self._components["cls_loss"].set_stable_model_update_probability( - args["cls_stable_model_update_probability"] - ) - if "cls_plastic_model_update_probability" in args: - self._components["cls_loss"].set_plastic_model_update_probability( - args["cls_plastic_model_update_probability"] - ) - class ExperienceReplayModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -644,6 +532,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -663,6 +553,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -684,6 +575,8 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) @@ -692,6 +585,7 @@ class DarkExperienceReplayModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -723,6 
+617,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -743,6 +639,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -764,6 +661,8 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) @@ -772,6 +671,7 @@ class PooledOutputDistillationExperienceReplayModelUpdater(SingleTrainingLoopUpd def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -804,6 +704,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -825,6 +727,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -846,6 +749,8 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) @@ -854,6 +759,7 @@ class CLSExperienceReplayModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -889,6 +795,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -913,6 +821,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -934,6 +843,8 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) @@ -942,6 +853,7 @@ class SuperExperienceReplayModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -983,6 +895,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -1013,6 +927,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -1034,5 +949,7 @@ def __init__( logger=logger, accelerator=accelerator, 
devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) diff --git a/src/renate/updaters/experimental/fine_tuning.py b/src/renate/updaters/experimental/fine_tuning.py index 0f166600..1a0db866 100644 --- a/src/renate/updaters/experimental/fine_tuning.py +++ b/src/renate/updaters/experimental/fine_tuning.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Callable, Dict, Optional +import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger @@ -15,6 +16,7 @@ class FineTuningModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 @@ -37,6 +39,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -50,6 +54,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -70,4 +75,6 @@ def __init__( accelerator=accelerator, devices=devices, deterministic_trainer=deterministic_trainer, + strategy=strategy, + precision=precision, ) diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index d0c87228..3d1bf239 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -12,7 +12,7 @@ from renate.memory import GreedyClassBalancingBuffer from renate.models import RenateModule from renate.types import NestedTensors -from renate.updaters.learner import Learner, ReplayLearner +from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater from renate.utils.pytorch import reinitialize_model_parameters @@ -47,6 +47,7 @@ def __init__( seed=seed, **kwargs, ) + self._memory_buffer = GreedyClassBalancingBuffer( max_size=memory_size, seed=seed, @@ -54,13 +55,9 @@ def __init__( target_transform=buffer_target_transform, ) - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner.""" - if not hasattr(self, "_memory_buffer"): - self._memory_buffer = GreedyClassBalancingBuffer() - Learner.load_state_dict(self, model, state_dict, **kwargs) - self._memory_batch_size = state_dict["memory_batch_size"] - self._memory_buffer.load_state_dict(state_dict["memory_buffer"]) + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_load_checkpoint(checkpoint) + self._memory_buffer.load_state_dict(checkpoint["memory_buffer"]) def on_model_update_start( self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None @@ -93,6 +90,7 @@ class GDumbModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, @@ -119,6 +117,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: 
defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -134,6 +134,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -155,5 +156,7 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index 90d4441e..0bcb9677 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -34,18 +34,13 @@ def __init__(self, **kwargs: Any) -> None: target_transform=self._train_target_transform, ) - def state_dict(self, **kwargs) -> Dict[str, Any]: - """Returns the state of the learner.""" - state_dict = super().state_dict(**kwargs) - state_dict["memory_buffer"] = self._memory_buffer.state_dict() - return state_dict + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_save_checkpoint(checkpoint=checkpoint) + checkpoint["memory_buffer"] = self._memory_buffer.state_dict() - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner.""" - if not hasattr(self, "_memory_buffer"): - self._memory_buffer = InfiniteBuffer() - super().load_state_dict(model, state_dict, **kwargs) - self._memory_buffer.load_state_dict(state_dict["memory_buffer"]) + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_load_checkpoint(checkpoint) + self._memory_buffer.load_state_dict(checkpoint["memory_buffer"]) def save(self, output_state_dir: str) -> None: super().save(output_state_dir) @@ -57,19 +52,6 @@ def load(self, input_state_dir: str) -> None: super().load(input_state_dir) self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer")) - def set_transforms( - self, - train_transform: Optional[Callable] = None, - train_target_transform: Optional[Callable] = None, - test_transform: Optional[Callable] = None, - test_target_transform: Optional[Callable] = None, - ) -> None: - """Update the transformations applied to the data.""" - super().set_transforms( - train_transform, train_target_transform, test_transform, test_target_transform - ) - self._memory_buffer.set_transforms(train_transform, train_target_transform) - def on_model_update_start( self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None ) -> None: @@ -101,6 +83,7 @@ class JointModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 @@ -123,6 +106,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -136,6 +121,7 @@ def __init__( "weight_decay": weight_decay, 
"batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -155,5 +141,7 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 4889b2ce..dc2779e1 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -97,35 +97,31 @@ def training_step( alpha = self._loss_weight_new_data inputs, targets = batch["current_task"] outputs = self(inputs) - loss = self._model.loss_fn(outputs, targets) + loss = self._loss_fn(outputs, targets) self._loss_collections["train_losses"]["base_loss"](loss) self._update_metrics(outputs, targets, "train") if "memory" in batch: (inputs_mem, targets_mem), _ = batch["memory"] outputs_mem = self(inputs_mem) - loss_mem = self._model.loss_fn(outputs_mem, targets_mem) + loss_mem = self._loss_fn(outputs_mem, targets_mem) self._loss_collections["train_losses"]["memory_loss"](loss_mem) loss = alpha * loss + (1.0 - alpha) * loss_mem return {"loss": loss} - def state_dict(self, **kwargs) -> Dict[str, Any]: - """Returns the state of the learner.""" - state_dict = super().state_dict(**kwargs) - state_dict["loss_weight_new_data"] = self._loss_weight_new_data - state_dict["num_points_previous_tasks"] = self._num_points_previous_tasks - return state_dict + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_save_checkpoint(checkpoint) + checkpoint["num_points_previous_tasks"] = self._num_points_previous_tasks - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner.""" - super().load_state_dict(model, state_dict, **kwargs) - self._loss_weight_new_data = state_dict["loss_weight_new_data"] - self._num_points_previous_tasks = state_dict["num_points_previous_tasks"] + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_load_checkpoint(checkpoint) + self._num_points_previous_tasks = checkpoint["num_points_previous_tasks"] class OfflineExperienceReplayModelUpdater(SingleTrainingLoopUpdater): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight_new_data: Optional[float] = None, @@ -153,6 +149,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): @@ -169,6 +167,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model, @@ -190,5 +189,7 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, deterministic_trainer=deterministic_trainer, ) diff --git a/src/renate/updaters/experimental/repeated_distill.py b/src/renate/updaters/experimental/repeated_distill.py index e10db492..0ee27b88 100644 --- a/src/renate/updaters/experimental/repeated_distill.py +++ b/src/renate/updaters/experimental/repeated_distill.py @@ -94,6 +94,7 @@ class RepeatedDistillationModelUpdater(ModelUpdater): def __init__( self, model: RenateModule, + 
loss_fn: torch.nn.Module, memory_size: int, optimizer: str = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, @@ -117,6 +118,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, seed: Optional[int] = None, early_stopping_enabled=False, @@ -133,6 +136,7 @@ def __init__( "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, + "loss_fn": loss_fn, } super().__init__( model=model, @@ -152,6 +156,8 @@ def __init__( logger=logger, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, early_stopping_enabled=early_stopping_enabled, logged_metrics=logged_metrics, deterministic_trainer=deterministic_trainer, @@ -208,6 +214,7 @@ class RepeatedDistillationLearner(ReplayLearner): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, memory_size: int, optimizer: str = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, @@ -246,6 +253,7 @@ def __init__( test_target_transform=test_target_transform, logged_metrics=logged_metrics, seed=seed, + loss_fn=loss_fn, ) self._expert_logits: Optional[torch.Tensor] = None diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index 56a09889..7a1c346b 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -57,6 +57,7 @@ class Learner(LightningModule, abc.ABC): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, learning_rate: float = defaults.LEARNING_RATE, learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 @@ -74,6 +75,7 @@ def __init__( ) -> None: super().__init__() self._model = model + self._loss_fn = loss_fn self._optimizer = optimizer self._learning_rate = learning_rate self._learning_rate_scheduler = learning_rate_scheduler @@ -91,10 +93,24 @@ def __init__( self._val_memory_buffer: DataBuffer = InfiniteBuffer() self._create_metrics_collections(logged_metrics) - self._post_init() - - def _post_init(self) -> None: self._rng = get_generator(self._seed) + self.save_hyperparameters( + ignore=[ + "model", + "loss_fn", + "components", + "train_transform", + "test_transform", + "buffer_transform", + "train_transform", + "train_target_transform", + "test_transform", + "test_target_transform", + "buffer_transform", + "buffer_target_transform", + "logged_metrics", + ] + ) def _create_metrics_collections( self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None @@ -125,53 +141,15 @@ def _create_metrics_collections( } ) - def state_dict(self, **kwargs) -> Dict[str, Any]: - """Returns the state of the learner.""" - return { + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + learner_state_dict = { "learner_class_name": self.__class__.__name__, - "optimizer": self._optimizer, - "learning_rate": self._learning_rate, - "learning_rate_scheduler": self._learning_rate_scheduler, - "learning_rate_scheduler_gamma": self._learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": self._learning_rate_scheduler_step_size, - "momentum": self._momentum, - "weight_decay": self._weight_decay, - "batch_size": self._batch_size, - "seed": self._seed, - "task_id": 
self._task_id, "val_memory_buffer": self._val_memory_buffer.state_dict(), } + checkpoint.update(learner_state_dict) - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner. - - Even though this is a LightningModule, no modules are stored. - - Args: - model: The model to be trained. - state_dict: Dictionary containing the state. - """ - if self.__class__.__name__ != state_dict["learner_class_name"]: - raise RuntimeError( - f"Learner of class {self.__class__} was used to load a state dict created by class " - f"{state_dict['learner_class_name']}." - ) - super().__init__() - self._model = model - self._optimizer = state_dict["optimizer"] - self._learning_rate = state_dict["learning_rate"] - self._learning_rate_scheduler = state_dict["learning_rate_scheduler"] - self._learning_rate_scheduler_gamma = state_dict["learning_rate_scheduler_gamma"] - self._learning_rate_scheduler_step_size = state_dict["learning_rate_scheduler_step_size"] - self._momentum = state_dict["momentum"] - self._weight_decay = state_dict["weight_decay"] - self._batch_size = state_dict["batch_size"] - self._seed = state_dict["seed"] - self._task_id = state_dict["task_id"] - if not hasattr(self, "_val_memory_buffer"): - self._val_memory_buffer = InfiniteBuffer() - self._val_memory_buffer.load_state_dict(state_dict["val_memory_buffer"]) - self._post_init() + def on_load_checkpoint(self, checkpoint: Dict[str, Any]): + self._val_memory_buffer.load_state_dict(checkpoint["val_memory_buffer"]) def save(self, output_state_dir: str) -> None: val_buffer_dir = os.path.join(output_state_dir, "val_memory_buffer") @@ -181,25 +159,6 @@ def save(self, output_state_dir: str) -> None: def load(self, input_state_dir: str) -> None: self._val_memory_buffer.load(os.path.join(input_state_dir, "val_memory_buffer")) - def set_transforms( - self, - train_transform: Optional[Callable] = None, - train_target_transform: Optional[Callable] = None, - test_transform: Optional[Callable] = None, - test_target_transform: Optional[Callable] = None, - ) -> None: - """Update the transformations applied to the data.""" - self._train_transform = train_transform - self._train_target_transform = train_target_transform - self._test_transform = test_transform - self._test_target_transform = test_target_transform - - def set_logged_metrics( - self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None - ) -> None: - """Sets the additional metrics logged during training and evaluation.""" - self._create_metrics_collections(logged_metrics) - def is_logged_metric(self, metric_name: str) -> bool: """Returns `True` if there is a metric with name `metric_name`.""" if metric_name is None: @@ -217,25 +176,6 @@ def is_logged_metric(self, metric_name: str) -> bool: ] return metric_name in logged_metrics - def update_hyperparameters(self, args: Dict[str, Any]) -> None: - """Update the hyperparameters of the learner.""" - if "optimizer" in args: - self._optimizer = args["optimizer"] - if "learning_rate" in args: - self._learning_rate = args["learning_rate"] - if "learning_rate_scheduler" in args: - self._learning_rate_scheduler = args["learning_rate_scheduler"] - if "learning_rate_scheduler_gamma" in args: - self._learning_rate_scheduler_gamma = args["learning_rate_scheduler_gamma"] - if "learning_rate_scheduler_step_size" in args: - self._learning_rate_scheduler_step_size = args["learning_rate_scheduler_step_size"] - if "momentum" in args: - self._momentum = args["momentum"] - if 
"weight_decay" in args: - self._weight_decay = args["weight_decay"] - if "batch_size" in args: - self._batch_size = args["batch_size"] - def on_model_update_start( self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None ) -> None: @@ -298,7 +238,7 @@ def training_step( outputs = self(inputs) intermediate_representation = self._model.get_intermediate_representation() self._model.reset_intermediate_representation_cache() - loss = self._model.loss_fn(outputs, targets) + loss = self._loss_fn(outputs, targets) self._update_metrics(outputs, targets, "train") self._loss_collections["train_losses"]["base_loss"](loss) return { @@ -323,7 +263,7 @@ def validation_step(self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: """PyTorch Lightning function to estimate validation metrics.""" (inputs, targets), _ = batch outputs = self(inputs) - loss = self._model.loss_fn(outputs, targets) + loss = self._loss_fn(outputs, targets) self._update_metrics(outputs, targets, "val") self._loss_collections["val_losses"]["loss"](loss) @@ -373,6 +313,7 @@ def _log_metrics( on_step=False, on_epoch=True, logger=True, + sync_dist=True, ) self._metric_collections[f"{prefix}_metrics"].reset() @@ -383,6 +324,7 @@ def _log_metrics( on_step=False, on_epoch=True, logger=True, + sync_dist=True, ) loss.reset() @@ -419,24 +361,13 @@ def __init__( target_transform=buffer_target_transform, ) - def state_dict(self, **kwargs) -> Dict[str, Any]: - """Returns the state of the learner.""" - state_dict = super().state_dict(**kwargs) - state_dict.update( - { - "memory_batch_size": self._memory_batch_size, - "memory_buffer": self._memory_buffer.state_dict(), - } - ) - return state_dict + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_save_checkpoint(checkpoint) + checkpoint["memory_buffer"] = self._memory_buffer.state_dict() - def load_state_dict(self, model: RenateModule, state_dict: Dict[str, Any], **kwargs) -> None: - """Restores the state of the learner.""" - super().load_state_dict(model, state_dict, **kwargs) - self._memory_batch_size = state_dict["memory_batch_size"] - if not hasattr(self, "_memory_buffer"): - self._memory_buffer = ReservoirBuffer() - self._memory_buffer.load_state_dict(state_dict["memory_buffer"]) + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_load_checkpoint(checkpoint) + self._memory_buffer.load_state_dict(checkpoint["memory_buffer"]) def save(self, output_state_dir: str) -> None: super().save(output_state_dir) @@ -447,21 +378,3 @@ def save(self, output_state_dir: str) -> None: def load(self, input_state_dir: str) -> None: super().load(input_state_dir) self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer")) - - def set_transforms( - self, - train_transform: Optional[Callable] = None, - train_target_transform: Optional[Callable] = None, - test_transform: Optional[Callable] = None, - test_target_transform: Optional[Callable] = None, - buffer_transform: Optional[Callable] = None, - buffer_target_transform: Optional[Callable] = None, - ) -> None: - """Update the transformations applied to the data.""" - super().set_transforms( - train_transform=train_transform, - train_target_transform=train_target_transform, - test_transform=test_transform, - test_target_transform=test_target_transform, - ) - self._memory_buffer.set_transforms(buffer_transform, buffer_target_transform) diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index fb78aea1..e086cc7a 100644 --- 
a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -11,12 +11,18 @@ from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.utilities.rank_zero import rank_zero_only from syne_tune import Reporter from torch.utils.data import Dataset from renate import defaults -from .learner import Learner, ReplayLearner +from renate.utils.deepspeed import convert_zero_checkpoint_to_fp32_state_dict +from renate.utils.distributed_strategies import create_strategy +from renate.utils.file import unlink_file_or_folder +from renate.utils.misc import int_or_str + from ..models import RenateModule +from .learner import Learner, ReplayLearner logging_logger = logging.getLogger(__name__) @@ -33,6 +39,7 @@ def __init__(self, val_enabled: bool): self._report = Reporter() self._val_enabled = val_enabled + @rank_zero_only def _log(self, trainer: Trainer, training: bool) -> None: """Report the current epoch's results to Syne Tune. @@ -96,7 +103,7 @@ def __init__( self._output_state_folder = output_state_folder self.CHECKPOINT_NAME_LAST = learner_checkpoint_filename # Delete old checkpoint if exists - Path(defaults.learner_state_file(self._output_state_folder)).unlink(missing_ok=True) + unlink_file_or_folder(Path(defaults.learner_state_file(self._output_state_folder))) # FIXME: Hack to make sure Syne Tune is called after checkpointing. # Details: https://github.com/Lightning-AI/lightning/issues/15026 # If fixed, remove on_train_epoch_end, on_validation_epoch_end, val_enabled, remove line @@ -106,11 +113,6 @@ def __init__( else: self._syne_tune_callback = None - def _save_checkpoint(self, trainer: Trainer, filepath: str) -> None: - Path(self.dirpath).mkdir(parents=True, exist_ok=True) - torch.save(self._model.state_dict(), defaults.model_file(self.dirpath)) - super()._save_checkpoint(trainer=trainer, filepath=filepath) - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: super().on_train_epoch_end(trainer=trainer, pl_module=pl_module) if self._syne_tune_callback is not None: @@ -125,8 +127,17 @@ def _load_best_checkpoint_and_save(self, trainer: Trainer, pl_module: LightningM # Reload best state. learner_state_path = Path(defaults.learner_state_file(self._output_state_folder)) if learner_state_path.exists(): - self._model.load_state_dict(torch.load(defaults.model_file(self.dirpath))) - pl_module.load_state_dict(self._model, torch.load(learner_state_path)["state_dict"]) + # There are three obvious steps that are handled by lightning if + # we use the load_from_checkpoint mechanism. Here we do those manually. + # See for reference + # https://github.com/Lightning-AI/lightning/blob/1.8.6/src/pytorch_lightning/core/saving.py#L225 + # 1. Load the state_dict from the checkpoint file. + # 2. Call the on_load_checkpoint (which is a callback) + # 3. Load the state_dict into the model. Note the strategy.load_model_state_dict call. + loaded_state = trainer.strategy.load_checkpoint(learner_state_path) + pl_module.on_load_checkpoint(loaded_state) + trainer.strategy.load_model_state_dict(loaded_state) + # Finalize model update. pl_module.on_model_update_end() # Save permanently. @@ -134,6 +145,49 @@ def _load_best_checkpoint_and_save(self, trainer: Trainer, pl_module: LightningM # Overwrite checkpoint. 
self._save_checkpoint(trainer, learner_state_path) + def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + """ + teardown implements the separation of learner and model at the end of training. + + There are two cases to handle. + + 1. If deepspeed is being used: + The learner_state_path (which the checkpointing function uses) is a directory and not a file. + This directory has sharded state_dicts (of model and optimizers, depending on which + deepspeed stage is used). There are four steps here: + + a. Combine all the shards into one big state dict. + b. The learner_state_path is a dir (learner.ckpt/). This needs to be deleted first. + c. Write the combined state_dict as the learner.ckpt file as a single file. + d. Extract the state_dict element from the learner and save that as the model.ckpt. + + 2. If not deepspeed (say DDP or single device): + The steps are much simpler. + + a. Load the learner.ckpt and extract the state_dict element. + b. | Sanitize the extracted state_dict. The learner has the model in a _model attribute. + | So strip the first "_model." from the keys of the state_dict. + c. Save the sanitized model to model.ckpt. + + Case 2 needs to be done even for Case 1 (step d). So teardown is a recursive call in + Case 1 which automatically goes to Case 2, as learner.ckpt is a file now. + """ + + if trainer.is_global_zero and (stage == "fit"): + learner_state_path = Path(defaults.learner_state_file(self._output_state_folder)) + if learner_state_path.exists() and learner_state_path.is_dir(): + # Deepspeed zero saves everything as folders. + combined_state_dict = convert_zero_checkpoint_to_fp32_state_dict(learner_state_path) + unlink_file_or_folder(learner_state_path) + torch.save(combined_state_dict, learner_state_path) + self.teardown(trainer, pl_module, stage) + elif learner_state_path.exists() and learner_state_path.is_file(): + ## This is a normal file. We strip the model of any wrappers and save that. + state_dict = torch.load(learner_state_path)["state_dict"] + out_sd = {k.replace("_model.", "", 1): v for k, v in state_dict.items()} + # Replace only 1 instance because we have to load it into RenateModule. + torch.save(out_sd, defaults.model_file(self.dirpath)) + def on_exception( self, trainer: Trainer, pl_module: LightningModule, exception: BaseException ) -> None: @@ -201,6 +255,8 @@ def __init__( logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): self._learner_kwargs = learner_kwargs or {} @@ -256,6 +312,8 @@ def __init__( ) self._accelerator = accelerator self._devices = devices + self._strategy = strategy + self._precision = int_or_str(precision) self._learner = self._load_learner(learner_class, self._learner_kwargs) assert self._learner.is_logged_metric(metric), f"Target metric `{metric}` is not logged."
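# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): the `teardown`
# hook above separates the learner checkpoint from the model checkpoint by
# stripping the leading "_model." prefix that the Learner adds when it wraps
# the RenateModule. The tiny example below only mirrors that key sanitization;
# the module and key names are hypothetical and chosen for demonstration.
import torch

learner_state = {
    "_model.linear.weight": torch.zeros(2, 2),
    "_model.linear.bias": torch.zeros(2),
}
# Strip only the first occurrence, exactly as in the dict comprehension above.
model_state = {k.replace("_model.", "", 1): v for k, v in learner_state.items()}

plain_model = torch.nn.Sequential()
plain_model.add_module("linear", torch.nn.Linear(2, 2))  # hypothetical model layout
plain_model.load_state_dict(model_state)  # loads once the wrapper prefix is gone
# ---------------------------------------------------------------------------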
self._logger = logger @@ -290,12 +348,14 @@ def _load_learner( logged_metrics=self._logged_metrics, **self._transforms_kwargs, ) - learner = learner_class.__new__(learner_class) - learner.load_state_dict(self._model, torch.load(self._learner_state_file)["state_dict"]) + learner = learner_class.load_from_checkpoint( + self._learner_state_file, + model=self._model, + logged_metrics=self._logged_metrics, + **self._transforms_kwargs, + **learner_kwargs, + ) learner.load(self._input_state_folder) - learner.set_transforms(**self._transforms_kwargs) - learner.set_logged_metrics(self._logged_metrics) - learner.update_hyperparameters(learner_kwargs) return learner def _fit_learner( @@ -326,6 +386,7 @@ def _fit_learner( "be ignored." ) + strategy = create_strategy(self._devices, self._strategy) trainer = Trainer( accelerator=self._accelerator, devices=self._devices, @@ -334,6 +395,8 @@ def _fit_learner( logger=self._logger, enable_progress_bar=False, deterministic=self._deterministic_trainer, + strategy=strategy, + precision=self._precision, ) trainer.fit(learner) self._num_epochs_trained = trainer.current_epoch diff --git a/src/renate/utils/deepspeed.py b/src/renate/utils/deepspeed.py new file mode 100644 index 00000000..ed762d09 --- /dev/null +++ b/src/renate/utils/deepspeed.py @@ -0,0 +1,112 @@ +# Copyright 2020 The PyTorch Lightning team and Microsoft Corporation. All rights reserved. +# Modifications: Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved +# SPDX-License-Identifier: Apache-2.0 + +import os +import pickle as pkl +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import torch +from deepspeed.utils.zero_to_fp32 import ( + get_fp32_state_dict_from_zero_checkpoint, + get_model_state_file, + get_optim_files, +) + +CPU_DEVICE = torch.device("cpu") + + +def ds_checkpoint_dir(checkpoint_dir: Union[str, Path], tag: Optional[str] = None) -> str: + if tag is None: + latest_path = os.path.join(checkpoint_dir, "latest") + if os.path.isfile(latest_path): + with open(latest_path) as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + directory = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(directory): + raise FileNotFoundError(f"Directory '{directory}' doesn't exist") + return directory + + +# modified +def search_key(state: Dict[str, Any], substring: str) -> str: + """This function looks for a substring in the keys of a dict and returns the first + matching full key.""" + for k in state.keys(): + if substring in k: + return k + + +# Modified script from +# https://github.com/Lightning-AI/lightning/blob/1.8.6/src/pytorch_lightning/utilities/deepspeed.py +# which is modified from +# https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py +def convert_zero_checkpoint_to_fp32_state_dict( + checkpoint_dir: Union[str, Path], tag: Optional[str] = None +) -> Dict[str, Any]: + """Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file + that can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training + without DeepSpeed. Additionally the script has been modified to ensure we keep the + lightning state inside the state dict for being able to run + ``LightningModule.load_from_checkpoint('...')``. + Modifications to this version include the explicit handling of the _extra_state + element of the state dict. Deepspeed's and Lightning's get-fp-32... functions only collate + trainable parameters.
+ + Args: + checkpoint_dir: path to the desired checkpoint folder. + (one that contains the tag-folder, like ``global_step14``) + tag: checkpoint tag used as a unique identifier for checkpoint. If not provided will + attempt to load tag in the file named ``latest`` in the checkpoint folder, + e.g., ``global_step14``. + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + # additional logic to ensure we keep the lightning state dict as well from rank 0. + deepspeed_states = [ + "module", + "optimizer", + "lr_scheduler", + "csr_tensor_module_names", + "skipped_steps", + "global_steps", + "dp_world_size", + "mp_world_size", + ] + checkpoint_dir = ds_checkpoint_dir(checkpoint_dir) + optim_files = get_optim_files(checkpoint_dir) + optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE) + zero_stage = optim_state["optimizer_state_dict"]["zero_stage"] + model_file = get_model_state_file(checkpoint_dir, zero_stage) + client_state = torch.load(model_file, map_location=CPU_DEVICE) + # Assign extra_state by searching for which key it is + extra_key = search_key(client_state["module"], "extra_state") + extra_state = client_state["module"][extra_key] + state_dict[extra_key] = extra_state + ## End of modifications + client_state = { + key: value for key, value in client_state.items() if key not in deepspeed_states + } + # State dict keys will include reference to wrapper _LightningModuleWrapperBase + # Delete `module` prefix before saving. + state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()} + client_state["state_dict"] = state_dict + + return client_state + + +def convert_to_tensor(obj): + """This function converts a pickleable object to a torch tensor. This is only to + aid saving with Deepspeed.""" + return torch.as_tensor(list(pkl.dumps(obj))) + + +def recover_object_from_tensor(tensor): + """This function converts a tensor to a byte stream that is passed through pickle + to recover the underlying object. For usage with Deepspeed""" + return pkl.loads(bytes(tensor.tolist())) diff --git a/src/renate/utils/distributed_strategies.py b/src/renate/utils/distributed_strategies.py new file mode 100644 index 00000000..28fe5082 --- /dev/null +++ b/src/renate/utils/distributed_strategies.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import warnings +from typing import Optional +from pytorch_lightning.strategies import Strategy, StrategyRegistry + + +_SUPPORTED_STRATEGIES = [ + "ddp_find_unused_parameters_false", + "ddp", + "deepspeed", + "deepspeed_stage_1", + "deepspeed_stage_2", + "deepspeed_stage_2_offload", + "deepspeed_stage_3", + "deepspeed_stage_3_offload", + "deepspeed_stage_3_offload_nvme", +] +_UNSUPPORTED_STRATEGIES = [ + x for x in StrategyRegistry.available_strategies() if x not in _SUPPORTED_STRATEGIES +] + + +def create_strategy(devices: int = 1, strategy_name: Optional["str"] = None) -> Strategy: + """Function returns a strategy object based on the number of devices queried + and name of strategy""" + + devices = devices or 1 + if strategy_name in _UNSUPPORTED_STRATEGIES: + raise ValueError( + f"Current strategy: {strategy_name} is unsupported. Choose deepspeed variants or ddp." + ) + if devices < 0: + raise ValueError("Number of devices has to be at least 0.") + + elif devices == 1: + # If one GPU, use standard training. 
Enabled by passing strategy=None + # to pl.Trainer + if strategy_name is not None: + warnings.warn(f"With devices=1, strategy is ignored. But got {strategy_name}.") + + return None + elif strategy_name in ["none", "None", None]: + ## Nothing is specified and devices > 1. Fall back to DDP. + return StrategyRegistry.get("ddp") + + elif "deepspeed" in strategy_name: + strategy = StrategyRegistry.get(strategy_name) + + ## TODO: This should be changed to instantiating Deepspeed and setting it in + # the constructor. This works for now because the flag forcing the PyTorch optimizer isn't + # used anywhere by Deepspeed. + strategy.config["zero_force_ds_cpu_optimizer"] = False + return strategy + + else: + # Any other supported strategy name: fall back to the registry. + return StrategyRegistry.get(strategy_name) diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py index 6aaca71b..9c19b43b 100644 --- a/src/renate/utils/file.py +++ b/src/renate/utils/file.py @@ -12,6 +12,7 @@ import pandas as pd import requests from botocore.exceptions import ClientError +from pytorch_lightning.utilities.rank_zero import rank_zero_only logger = logging.getLogger(__name__) @@ -277,3 +278,17 @@ def save_pandas_df_to_csv(df: pd.DataFrame, file_path: Union[str, Path]) -> pd.D """ df.to_csv(file_path, index=False) return df + + +@rank_zero_only +def unlink_file_or_folder(path: Path) -> None: + """Function to remove files and folders. + + Unlink works for files, rmdir for empty folders, but not for non-empty ones. Hence, + non-empty folders are removed recursively with shutil.rmtree. + """ + if path.exists(): + if path.is_file(): + path.unlink(missing_ok=True) + else: + shutil.rmtree(path) diff --git a/src/renate/utils/misc.py b/src/renate/utils/misc.py new file mode 100644 index 00000000..3e12ab5f --- /dev/null +++ b/src/renate/utils/misc.py @@ -0,0 +1,14 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Union + + +def int_or_str(x: str) -> Union[str, int]: + """Function to cast to int or str. + + This is used to handle precision, which can be an int (16, 32) or a str ("bf16"). + """ + try: + return int(x) + except ValueError: + return x diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 23516c07..a45766bc 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -5,6 +5,7 @@ from types import ModuleType from typing import Any, Callable, Dict, List, Optional, Union +import torch import torchmetrics from renate import defaults @@ -26,6 +27,8 @@ def evaluate_and_record_results( batch_size: int = defaults.BATCH_SIZE, accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, ) -> Dict[str, List[List[float]]]: """A helper function that performs the evaluation on test data and records quantitative metrics in a dictionary.
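# ---------------------------------------------------------------------------
# Editor's note (illustrative sketch, not part of the patch): `int_or_str`
# from src/renate/utils/misc.py above normalizes the `precision` setting that
# is now threaded through the updaters. Lightning 1.x typically accepts
# precision as an int (16, 32, 64) or the string "bf16", so numeric strings
# are cast to int and anything else is passed through unchanged.
from renate.utils.misc import int_or_str  # module added by this patch

assert int_or_str("32") == 32        # numeric string becomes an int
assert int_or_str("16") == 16
assert int_or_str("bf16") == "bf16"  # non-numeric string stays a string
# ---------------------------------------------------------------------------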
@@ -58,6 +61,8 @@ def evaluate_and_record_results( logged_metrics=logged_metrics, accelerator=accelerator, devices=devices, + strategy=strategy, + precision=precision, ) for key, value in update_results.items(): if key not in results: @@ -76,6 +81,11 @@ def get_data_module(config_module: ModuleType, **kwargs: Any) -> RenateDataModul return getattr(config_module, "data_module_fn")(**kwargs) +def get_loss_fn(config_module: ModuleType, **kwargs: Any) -> torch.nn.Module: + """Creates and returns the loss function from config""" + return getattr(config_module, "loss_fn")(**kwargs) + + def get_metrics(config_module: ModuleType) -> Dict[str, torchmetrics.Metric]: """Creates and returns a dictionary of metrics.""" metrics_fn_name = "metrics_fn" diff --git a/test/conftest.py b/test/conftest.py index d3445ced..ad64b255 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -76,6 +76,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, Learner: { "optimizer": "SGD", @@ -84,6 +85,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.005, "batch_size": 10, "seed": 42, + "loss_fn": torch.nn.CrossEntropyLoss(), }, GDumbLearner: { "optimizer": "SGD", @@ -93,6 +95,7 @@ def pytest_collection_modifyitems(config, items): "batch_size": 10, "seed": 42, "memory_size": 30, + "loss_fn": torch.nn.CrossEntropyLoss(), }, JointLearner: { "optimizer": "SGD", @@ -101,6 +104,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.001, "batch_size": 10, "seed": 3, + "loss_fn": torch.nn.CrossEntropyLoss(), }, RepeatedDistillationLearner: { "optimizer": "SGD", @@ -110,6 +114,7 @@ def pytest_collection_modifyitems(config, items): "batch_size": 10, "seed": 42, "memory_size": 30, + "loss_fn": torch.nn.CrossEntropyLoss(), }, OfflineExperienceReplayLearner: { "memory_size": 30, @@ -121,6 +126,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, } AVALANCHE_LEARNER_KWARGS = { @@ -133,6 +139,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheEWCLearner: { "ewc_lambda": 0.1, @@ -142,6 +149,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheLwFLearner: { "alpha": 0.1, @@ -152,6 +160,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheICaRLLearner: { "memory_size": 30, @@ -162,6 +171,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, + "loss_fn": torch.nn.CrossEntropyLoss(), }, } LEARNER_HYPERPARAMETER_UPDATES = { @@ -171,12 +181,14 @@ def pytest_collection_modifyitems(config, items): "momentum": 0.5, "weight_decay": 0.01, "batch_size": 128, + "loss_fn": torch.nn.CrossEntropyLoss(), }, Learner: { "optimizer": "Adam", "learning_rate": 3.0, "weight_decay": 0.01, "batch_size": 128, + "loss_fn": torch.nn.CrossEntropyLoss(), }, GDumbLearner: { "optimizer": "Adam", @@ -185,18 +197,21 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.03, "batch_size": 128, "memory_size": 50, + "loss_fn": torch.nn.CrossEntropyLoss(), }, JointLearner: { "optimizer": "Adam", "learning_rate": 2.0, "weight_decay": 0.01, "batch_size": 128, + 
"loss_fn": torch.nn.CrossEntropyLoss(), }, RepeatedDistillationLearner: { "optimizer": "Adam", "learning_rate": 2.0, "weight_decay": 0.01, "batch_size": 128, + "loss_fn": torch.nn.CrossEntropyLoss(), }, OfflineExperienceReplayLearner: { "optimizer": "Adam", @@ -204,6 +219,7 @@ def pytest_collection_modifyitems(config, items): "momentum": 0.5, "weight_decay": 0.01, "batch_size": 128, + "loss_fn": torch.nn.CrossEntropyLoss(), }, } AVALANCHE_LEARNER_HYPERPARAMETER_UPDATES = { @@ -254,6 +270,11 @@ def get_renate_module_mlp( ) +@pytest.helpers.register +def get_loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() + + @pytest.helpers.register def get_renate_module_resnet(sub_class="resnet18cifar", **kwargs) -> RenateModule: kwargs["add_icarl_class_means"] = False @@ -338,6 +359,31 @@ def get_renate_module_mlp_and_data( return model, train_dataset, test_data +@pytest.helpers.register +def get_renate_module_mlp_data_and_loss( + num_inputs, + num_outputs, + num_hidden_layers, + hidden_size, + train_num_samples, + test_num_samples, + val_num_samples=0, + add_icarl_class_means=False, +): + model, ds, test_data = get_renate_module_mlp_and_data( + num_inputs, + num_outputs, + num_hidden_layers, + hidden_size, + train_num_samples, + test_num_samples, + val_num_samples, + add_icarl_class_means, + ) + + return model, ds, test_data, get_loss_fn() + + @pytest.helpers.register def get_renate_vision_module_and_data( input_size, @@ -363,7 +409,7 @@ def get_simple_updater( input_state_folder=None, output_state_folder=None, learner_class=ExperienceReplayLearner, - learner_kwargs={"memory_size": 10}, + learner_kwargs={"memory_size": 10, "loss_fn": pytest.helpers.get_loss_fn()}, max_epochs=5, train_transform=None, train_target_transform=None, @@ -406,7 +452,7 @@ def get_avalanche_updater( input_state_folder=None, output_state_folder=None, learner_class=AvalancheReplayLearner, - learner_kwargs={"memory_size": 10}, + learner_kwargs={"memory_size": 10, "loss_fn": torch.nn.CrossEntropyLoss()}, max_epochs=5, train_transform=None, train_target_transform=None, diff --git a/test/integration_tests/run_experiment.py b/test/integration_tests/run_experiment.py index ed0c76be..a4f4623a 100644 --- a/test/integration_tests/run_experiment.py +++ b/test/integration_tests/run_experiment.py @@ -44,7 +44,7 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): f"--test-suite", type=str, required=True, - choices=["quick"], + choices=["quick", "main"], help="Test suite that is run.", ) parser.add_argument( @@ -88,4 +88,6 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): max_time=args.max_time, seed=args.seed, job_name=args.job_name[:36], + devices=1, + strategy="ddp", ) diff --git a/test/renate/benchmark/test_experimentation.py b/test/renate/benchmark/test_experimentation.py index 9026fc53..9ce75250 100644 --- a/test/renate/benchmark/test_experimentation.py +++ b/test/renate/benchmark/test_experimentation.py @@ -19,8 +19,10 @@ def experiment_job_kwargs(): "mode": "max", "metric": "val_accuracy", "num_updates": 2, - "max_time": 15, + "max_time": 30, "seed": 0, + "accelerator": "cpu", + "devices": 1, } diff --git a/test/renate/models/test_renate_module.py b/test/renate/models/test_renate_module.py index 1122b5ef..318265eb 100644 --- a/test/renate/models/test_renate_module.py +++ b/test/renate/models/test_renate_module.py @@ -12,6 +12,7 @@ from renate.defaults import TASK_ID from renate.models import RenateModule from renate.models.renate_module import RenateWrapper +from 
renate.utils.deepspeed import recover_object_from_tensor def test_failing_to_init_abs_class(): @@ -36,13 +37,13 @@ def test_renate_model_save(tmpdir, model): state = torch.load(os.path.join(tmpdir, "test_model.pt")) os.remove(os.path.join(tmpdir, "test_model.pt")) + state["_extra_state"] = recover_object_from_tensor(state["_extra_state"]) assert "constructor_arguments" in state["_extra_state"].keys() - assert "loss_fn" in state["_extra_state"].keys() assert "tasks_params_ids" in state["_extra_state"].keys() @pytest.mark.parametrize( - "test_case,test_cls", + "test_case,test_cls,loss_fn", [ [ pytest.helpers.get_renate_module_mlp_and_data( @@ -54,6 +55,7 @@ def test_renate_model_save(tmpdir, model): test_num_samples=5, ), MultiLayerPerceptron, + torch.nn.CrossEntropyLoss, ], [ pytest.helpers.get_renate_vision_module_and_data( @@ -65,6 +67,7 @@ def test_renate_model_save(tmpdir, model): test_num_samples=5, ), ResNet, + torch.nn.CrossEntropyLoss, ], [ pytest.helpers.get_renate_vision_module_and_data( @@ -76,16 +79,17 @@ def test_renate_model_save(tmpdir, model): test_num_samples=5, ), VisionTransformer, + torch.nn.CrossEntropyLoss, ], ], ) -def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls): +def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls, loss_fn): model, _, test_data = test_case model.eval() y = torch.randint(3, 8, (5,)) y_hat_pre_save = model(test_data) - loss_pre_save = model.loss_fn(y_hat_pre_save, y) + loss_pre_save = loss_fn()(y_hat_pre_save, y) torch.save(model.state_dict(), os.path.join(tmpdir, "test_model.pt")) state = torch.load(os.path.join(tmpdir, "test_model.pt")) @@ -94,7 +98,7 @@ def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls): model2 = test_cls.from_state_dict(state) model2.eval() y_hat_post_load = model2(test_data) - loss_post_load = model.loss_fn(y_hat_post_load, y) + loss_post_load = loss_fn()(y_hat_post_load, y) assert torch.allclose(y_hat_pre_save, y_hat_post_load) assert torch.allclose(loss_post_load, loss_pre_save) @@ -313,7 +317,7 @@ def test_renate_multihead_multi_param_update(test_case): ) @torch.no_grad() def test_renate_wrapper_save_and_load(tmpdir, torch_model): - renate_module = RenateWrapper(torch_model, loss_fn=torch.nn.CrossEntropyLoss()) + renate_module = RenateWrapper(torch_model) X = torch.randn(10, 3) output_before = renate_module(X) assert torch.equal(output_before, torch_model(X)) @@ -323,7 +327,7 @@ def test_renate_wrapper_save_and_load(tmpdir, torch_model): state = torch.load(os.path.join(tmpdir, "test_model.pt")) os.remove(os.path.join(tmpdir, "test_model.pt")) - renate_module = RenateWrapper(torch_model, loss_fn=torch.nn.CrossEntropyLoss()) + renate_module = RenateWrapper(torch_model) renate_module.load_state_dict(state) output_after = renate_module(X) @@ -331,7 +335,7 @@ def test_renate_wrapper_save_and_load(tmpdir, torch_model): def test_renate_wrapper_forbids_from_state_dict(): - renate_module = RenateWrapper(torch.nn.Linear(1, 1), loss_fn=torch.nn.CrossEntropyLoss()) + renate_module = RenateWrapper(torch.nn.Linear(1, 1)) state_dict = renate_module.state_dict() del renate_module with pytest.raises(NotImplementedError): diff --git a/test/renate/renate_config_files/config.py b/test/renate/renate_config_files/config.py index 46b985a9..3bfdeb5f 100644 --- a/test/renate/renate_config_files/config.py +++ b/test/renate/renate_config_files/config.py @@ -28,3 +28,7 @@ def data_module_fn( bool_param: bool = False, ) -> RenateDataModule: return 
DummyTorchVisionDataModule(transform=None, val_size=val_size, seed=seed) + + +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() diff --git a/test/renate/renate_config_files/config_scenario.py b/test/renate/renate_config_files/config_scenario.py index 5e4f9449..70d81c6f 100644 --- a/test/renate/renate_config_files/config_scenario.py +++ b/test/renate/renate_config_files/config_scenario.py @@ -31,3 +31,7 @@ def data_module_fn( chunk_id=chunk_id, class_groupings=class_groupings, ) + + +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss() diff --git a/test/renate/training/test_run_training.py b/test/renate/training/test_run_training.py index dd1808d2..399d15cd 100644 --- a/test/renate/training/test_run_training.py +++ b/test/renate/training/test_run_training.py @@ -69,7 +69,7 @@ def execute_job(): "val_size": val_size, }, metric="val_accuracy", - max_time=15, + max_time=30, scheduler=scheduler, ) diff --git a/test/renate/updaters/avalanche/test_avalanche_learner.py b/test/renate/updaters/avalanche/test_avalanche_learner.py index 1d2ed0c9..ae6aaebc 100644 --- a/test/renate/updaters/avalanche/test_avalanche_learner.py +++ b/test/renate/updaters/avalanche/test_avalanche_learner.py @@ -42,7 +42,7 @@ def check_learner_settings( assert avalanche_learner.eval_every == expected_eval_every assert avalanche_learner.model == expected_model if expected_loss_fn is None: - assert avalanche_learner._criterion == expected_model.loss_fn + assert avalanche_learner._criterion == learner_kwargs["loss_fn"] else: assert avalanche_learner._criterion == expected_loss_fn assert avalanche_learner.optimizer == expected_optimizer diff --git a/test/renate/updaters/avalanche/test_avalanche_model_updater.py b/test/renate/updaters/avalanche/test_avalanche_model_updater.py index 97424912..e0df1d89 100644 --- a/test/renate/updaters/avalanche/test_avalanche_model_updater.py +++ b/test/renate/updaters/avalanche/test_avalanche_model_updater.py @@ -87,12 +87,14 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b "memory_size": memory_size, "memory_batch_size": memory_batch_size, "batch_size": batch_size, + "loss_fn": pytest.helpers.get_loss_fn(), } model_updater = ExperienceReplayAvalancheModelUpdater( output_state_folder=Path(tmpdir) / "0", model=model, **learner_kwargs, max_epochs=1, + accelerator="cpu", ) model_updater.update(train_dataset=dataset) replay_plugin = plugin_by_class(ReplayPlugin, model_updater._learner.plugins) diff --git a/test/renate/updaters/experimental/test_er.py b/test/renate/updaters/experimental/test_er.py index 8434d6a3..ae89346e 100644 --- a/test/renate/updaters/experimental/test_er.py +++ b/test/renate/updaters/experimental/test_er.py @@ -36,6 +36,7 @@ def test_er_overall_memory_size_after_update(batch_size, memory_size, memory_bat "memory_size": memory_size, "memory_batch_size": memory_batch_size, "batch_size": batch_size, + "loss_fn": pytest.helpers.get_loss_fn(), } model_updater = pytest.helpers.get_simple_updater( model=model, @@ -91,8 +92,19 @@ def test_er_validation_buffer(tmpdir): @pytest.mark.parametrize( "cls,kwargs", [ - [ExperienceReplayLearner, {"alpha": 0.2, "memory_size": 10, "memory_batch_size": 10}], - [DarkExperienceReplayLearner, {"alpha": 0.1, "beta": 0.3, "memory_size": 42}], + [ + ExperienceReplayLearner, + { + "alpha": 0.2, + "memory_size": 10, + "memory_batch_size": 10, + "loss_fn": pytest.helpers.get_loss_fn(), + }, + ], + [ + DarkExperienceReplayLearner, + {"alpha": 0.1, "beta": 0.3, "memory_size": 42, 
"loss_fn": pytest.helpers.get_loss_fn()}, + ], [ CLSExperienceReplayLearner, { @@ -103,11 +115,18 @@ def test_er_validation_buffer(tmpdir): "stable_model_update_probability": 0.3, "plastic_model_update_probability": 0.5, "memory_size": 42, + "loss_fn": pytest.helpers.get_loss_fn(), }, ], [ PooledOutputDistillationExperienceReplayLearner, - {"alpha": 0.3, "distillation_type": "pixel", "normalize": False, "memory_size": 42}, + { + "alpha": 0.3, + "distillation_type": "pixel", + "normalize": False, + "memory_size": 42, + "loss_fn": pytest.helpers.get_loss_fn(), + }, ], [ SuperExperienceReplayLearner, @@ -125,6 +144,7 @@ def test_er_validation_buffer(tmpdir): "pod_distillation_type": "pixel", "pod_normalize": False, "memory_size": 42, + "loss_fn": pytest.helpers.get_loss_fn(), }, ], ], @@ -137,8 +157,8 @@ def test_er_components_save_and_load(tmpdir, cls, kwargs): ) learner = cls(model=model, **kwargs) torch.save(learner.state_dict(), os.path.join(tmpdir, "learner.pt")) - learner = cls.__new__(cls) - learner.load_state_dict(model, torch.load(os.path.join(tmpdir, "learner.pt"))) + learner = cls(model=model, **kwargs) + learner.load_state_dict(torch.load(os.path.join(tmpdir, "learner.pt"))) if isinstance(learner, ExperienceReplayLearner) and not isinstance( learner, DarkExperienceReplayLearner ): diff --git a/test/renate/updaters/experimental/test_fine_tuning.py b/test/renate/updaters/experimental/test_fine_tuning.py index 17a29962..18098686 100644 --- a/test/renate/updaters/experimental/test_fine_tuning.py +++ b/test/renate/updaters/experimental/test_fine_tuning.py @@ -20,13 +20,23 @@ def get_model_and_dataset(): return model, dataset -def test_fine_tuning_updater(): +@pytest.mark.parametrize("devices", [None, 1]) +@pytest.mark.parametrize("strategy", [None, "ddp"]) +@pytest.mark.parametrize("accelerator", ["cpu"]) +def test_fine_tuning_updater(devices, strategy, accelerator): """This test checks that the memory buffer is updated correctly.""" init_model, dataset = get_model_and_dataset() model = copy.deepcopy(init_model) - model_updater = FineTuningModelUpdater(model, max_epochs=1) + model_updater = FineTuningModelUpdater( + model, + loss_fn=pytest.helpers.get_loss_fn(), + max_epochs=1, + devices=devices, + strategy=strategy, + accelerator=accelerator, + ) model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) for p1, p2 in zip(init_model.parameters(), model.parameters()): diff --git a/test/renate/updaters/experimental/test_joint.py b/test/renate/updaters/experimental/test_joint.py index cdb0f2f4..597251c5 100644 --- a/test/renate/updaters/experimental/test_joint.py +++ b/test/renate/updaters/experimental/test_joint.py @@ -27,7 +27,7 @@ def test_joint_learner_memory_append(): model_updater = pytest.helpers.get_simple_updater( model=model, learner_class=JointLearner, - learner_kwargs={}, + learner_kwargs={"loss_fn": pytest.helpers.get_loss_fn()}, max_epochs=1, ) model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) @@ -42,7 +42,7 @@ def test_joint_learner_model_reset(): model_updater = pytest.helpers.get_simple_updater( model=model, learner_class=JointLearner, - learner_kwargs={"learning_rate": 0.0}, + learner_kwargs={"learning_rate": 0.0, "loss_fn": torch.nn.CrossEntropyLoss()}, max_epochs=1, ) model = model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) diff --git a/test/renate/updaters/experimental/test_repeated_distill.py b/test/renate/updaters/experimental/test_repeated_distill.py index f9ce542e..4d48d0cc 100644 --- 
a/test/renate/updaters/experimental/test_repeated_distill.py +++ b/test/renate/updaters/experimental/test_repeated_distill.py @@ -28,12 +28,9 @@ def test_dmc_runs_end_to_end(): data.append(ds) val = ConcatDataset(data) - + loss_fn = pytest.helpers.get_loss_fn() updater = RepeatedDistillationModelUpdater( - model=mlp, - memory_size=300, - batch_size=20, - max_epochs=5, + model=mlp, memory_size=300, batch_size=20, max_epochs=5, loss_fn=loss_fn, accelerator="cpu" ) for i in range(len(data)): @@ -46,10 +43,9 @@ def test_dmc_memory_size_after_update(memory_size, dataset_size): model = pytest.helpers.get_renate_module_mlp( num_inputs=10, num_outputs=3, hidden_size=20, num_hidden_layers=1 ) + loss_fn = pytest.helpers.get_loss_fn() model_updater = RepeatedDistillationModelUpdater( - model=model, - memory_size=memory_size, - max_epochs=1, + model=model, memory_size=memory_size, max_epochs=1, loss_fn=loss_fn, accelerator="cpu" ) datasets = [ TensorDataset( @@ -67,7 +63,7 @@ def test_dmc_memory_size_after_update(memory_size, dataset_size): @pytest.mark.parametrize("provide_folder", [True, False]) def test_dmc_model_updater(tmpdir, provide_folder): - model, train_dataset, test_data = pytest.helpers.get_renate_module_mlp_and_data( + model, train_dataset, test_data, loss_fn = pytest.helpers.get_renate_module_mlp_data_and_loss( num_inputs=10, num_outputs=10, hidden_size=32, @@ -77,9 +73,11 @@ def test_dmc_model_updater(tmpdir, provide_folder): ) model_updater = RepeatedDistillationModelUpdater( model, + loss_fn=loss_fn, memory_size=50, max_epochs=1, output_state_folder=defaults.output_state_folder(tmpdir) if provide_folder else None, + accelerator="cpu", ) y_hat_before_train = model(test_data, task_id=defaults.TASK_ID) model_updater.update(train_dataset, task_id=defaults.TASK_ID) @@ -89,7 +87,7 @@ def test_dmc_model_updater(tmpdir, provide_folder): def test_continuation_of_training_with_dmc_model_updater(tmpdir): - model, train_dataset, _ = pytest.helpers.get_renate_module_mlp_and_data( + model, train_dataset, _, loss_fn = pytest.helpers.get_renate_module_mlp_data_and_loss( num_inputs=10, num_outputs=10, hidden_size=32, @@ -99,11 +97,21 @@ def test_continuation_of_training_with_dmc_model_updater(tmpdir): ) state_url = defaults.input_state_folder(tmpdir) model_updater = RepeatedDistillationModelUpdater( - model, memory_size=50, max_epochs=1, output_state_folder=state_url + model, + memory_size=50, + max_epochs=1, + output_state_folder=state_url, + loss_fn=loss_fn, + accelerator="cpu", ) model = model_updater.update(train_dataset, task_id=defaults.TASK_ID) model_updater = RepeatedDistillationModelUpdater( - model, memory_size=50, max_epochs=1, input_state_folder=state_url + model, + memory_size=50, + max_epochs=1, + input_state_folder=state_url, + loss_fn=loss_fn, + accelerator="cpu", ) model_updater.update(train_dataset, task_id=defaults.TASK_ID) diff --git a/test/renate/updaters/test_learner.py b/test/renate/updaters/test_learner.py index 8194b0fd..e95cffcd 100644 --- a/test/renate/updaters/test_learner.py +++ b/test/renate/updaters/test_learner.py @@ -1,20 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -import os from typing import Any, Dict, Tuple, Type import pytest -import torch -from torchvision.transforms import ToTensor +from conftest import LEARNER_KWARGS, LEARNERS -from conftest import ( - LEARNERS, - LEARNER_HYPERPARAMETER_UPDATES, - LEARNER_KWARGS, - check_learner_transforms, -) from renate.models import RenateModule -from renate.updaters.learner import Learner, ReplayLearner +from renate.updaters.learner import Learner def get_model_and_learner_and_learner_kwargs( @@ -30,7 +22,12 @@ def get_model_and_learner_and_learner_kwargs( def check_learner_variables(learner: Learner, expected_variable_values: Dict[str, Any]): for attribute_name, attribute_value in expected_variable_values.items(): - if attribute_name == "memory_size": + if attribute_name in [ + "memory_size", + "learner_class_name", + "val_memory_buffer", + "memory_buffer", + ]: continue assert getattr(learner, f"_{attribute_name}") == attribute_value @@ -38,38 +35,6 @@ def check_learner_variables(learner: Learner, expected_variable_values: Dict[str @pytest.mark.parametrize("learner_class", LEARNERS) def test_save_and_load_learner(tmpdir, learner_class): model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) - filename = os.path.join(tmpdir, "learner.pkl") - torch.save(learner.state_dict(), filename) - learner = learner_class.__new__(learner_class) - learner.load_state_dict(model, torch.load(filename)) - check_learner_variables(learner, learner_kwargs) - - -@pytest.mark.parametrize("learner_class", LEARNERS) -def test_update_hyperparameters(learner_class): - model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) - check_learner_variables(learner, learner_kwargs) - learner.update_hyperparameters({}) - check_learner_variables(learner, learner_kwargs) - learner.update_hyperparameters(LEARNER_HYPERPARAMETER_UPDATES[learner_class]) - learner_kwargs = dict(learner_kwargs) - learner_kwargs.update(LEARNER_HYPERPARAMETER_UPDATES[learner_class]) - check_learner_variables(learner, learner_kwargs) - - -@pytest.mark.parametrize("learner_class", LEARNERS) -def test_set_transforms(learner_class): - """Tests if set_transforms function correctly sets transforms in Learner and MemoryBuffer.""" - model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) - check_learner_transforms(learner, {}) - transforms_kwargs = { - "train_transform": ToTensor(), - "train_target_transform": ToTensor(), - "test_transform": ToTensor(), - "test_target_transform": ToTensor(), - } - if issubclass(learner_class, ReplayLearner): - transforms_kwargs["buffer_transform"] = ToTensor() - transforms_kwargs["buffer_target_transform"] = ToTensor() - learner.set_transforms(**transforms_kwargs) - check_learner_transforms(learner, transforms_kwargs) + checkpoint_dict = {} + learner.on_save_checkpoint(checkpoint=checkpoint_dict) + check_learner_variables(learner, checkpoint_dict) diff --git a/test/renate/updaters/test_model_updater.py b/test/renate/updaters/test_model_updater.py index fe36583e..66b1dfc5 100644 --- a/test/renate/updaters/test_model_updater.py +++ b/test/renate/updaters/test_model_updater.py @@ -73,7 +73,7 @@ def test_deterministic_updater(): def test_model_updater_with_early_stopping( use_val, early_stopping_enabled, metric_monitored, updater_type ): - model, train_dataset, val_dataset = pytest.helpers.get_renate_module_mlp_and_data( + model, train_dataset, val_dataset, loss = 
pytest.helpers.get_renate_module_mlp_data_and_loss( num_inputs=10, num_outputs=10, hidden_size=8, @@ -88,10 +88,13 @@ def test_model_updater_with_early_stopping( if updater_type == "DMC": model_updater = RepeatedDistillationModelUpdater( model=model, + loss_fn=loss, memory_size=50, max_epochs=max_epochs, early_stopping_enabled=early_stopping_enabled, metric=metric_monitored, + accelerator="cpu", + devices=1, ) else: model_updater = pytest.helpers.get_simple_updater( @@ -204,6 +207,23 @@ def identity_transform(): check_learner_transforms(model_updater._learner, transforms_kwargs) model = model_updater.update(train_dataset, task_id=defaults.TASK_ID) model_updater = pytest.helpers.get_simple_updater( - model, input_state_folder=state_url, learner_class=learner_class, **transforms_kwargs + model, + input_state_folder=state_url, + learner_class=learner_class, + **transforms_kwargs, + learner_kwargs=LEARNER_KWARGS[learner_class], ) check_learner_transforms(model_updater._learner, transforms_kwargs) + + +@pytest.mark.xfail(raises=(KeyError, TypeError)) +@pytest.mark.parametrize("learner", LEARNERS_USING_SIMPLE_UPDATER) +def test_learner_fails_without_loss_fn(learner): + """This test checks that the updater crashes when it is supposed to. + This is to check for loss_fn and other misc arguments.""" + _ = pytest.helpers.get_simple_updater( + model=torch.nn.Linear(1, 1), + learner_class=learner, + learner_kwargs={"learning_rate": 0.0}, + max_epochs=1, + ) diff --git a/test/renate/utils/test_deepspeed.py b/test/renate/utils/test_deepspeed.py new file mode 100644 index 00000000..ce01493f --- /dev/null +++ b/test/renate/utils/test_deepspeed.py @@ -0,0 +1,26 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from renate.utils.deepspeed import convert_to_tensor, recover_object_from_tensor + + +@pytest.mark.parametrize( + "obj", + [ + { + "constructor_arguments": { + "a": 1, + "b": "2", + "nested_dict": { + "k1": "v1", + "k2": "v2", + }, + }, + "tasks_params_ids": "task_param_id_1", + "misc_args": tuple(range(10)), + } + ], +) +def test_serialize_random_objects(obj): + assert recover_object_from_tensor(convert_to_tensor(obj)) == obj diff --git a/test/renate/utils/test_distributed_strategies.py b/test/renate/utils/test_distributed_strategies.py new file mode 100644 index 00000000..716e19c9 --- /dev/null +++ b/test/renate/utils/test_distributed_strategies.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import pytest + +from renate.utils.distributed_strategies import ( + create_strategy, + _SUPPORTED_STRATEGIES, + _UNSUPPORTED_STRATEGIES, +) + + +@pytest.mark.parametrize("devices", [1, 2, 3, 10]) +@pytest.mark.parametrize("strategy_name", _SUPPORTED_STRATEGIES) +def test_valid_strategy_creation(devices, strategy_name): + assert isinstance(create_strategy(devices, strategy_name), (str, object, None)) + + +@pytest.mark.parametrize("devices", [1, 2, 3, 10]) +@pytest.mark.parametrize("strategy_name", _UNSUPPORTED_STRATEGIES) +def test_invalid_strategy_creation(devices, strategy_name): + with pytest.raises(ValueError) as _: + create_strategy(devices, strategy_name) diff --git a/test/renate/utils/test_misc.py b/test/renate/utils/test_misc.py new file mode 100644 index 00000000..d02d248e --- /dev/null +++ b/test/renate/utils/test_misc.py @@ -0,0 +1,19 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import pytest + +from renate.utils.misc import int_or_str + + +@pytest.mark.parametrize( + "data_type,target", + [ + ["16", 16], + ["32", 32], + ["bfloat", "bfloat"], + ["bfloat16", "bfloat16"], + ["notdata", "notdata"], + ], +) +def test_int_or_str(data_type, target): + assert int_or_str(data_type) == target From 699ab3edbb2901c94c2673595231ec3769a5df5d Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Fri, 26 May 2023 14:03:38 +0200 Subject: [PATCH 18/89] Remove obsolete `set_transforms` from memory buffer (#265) --- src/renate/memory/buffer.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/renate/memory/buffer.py b/src/renate/memory/buffer.py index 3849c9ab..f20480d9 100644 --- a/src/renate/memory/buffer.py +++ b/src/renate/memory/buffer.py @@ -126,13 +126,6 @@ def set_metadata(self, key: str, values: torch.Tensor) -> None: self._add_metadata_like({key: values}) self._metadata[key][: len(self)] = values.cpu() - def set_transforms( - self, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None - ) -> None: - """Update the transformations applied to the data.""" - self._transform = transform - self._target_transform = target_transform - def state_dict(self) -> Dict: return { "buffer_class_name": self.__class__.__name__, @@ -173,7 +166,7 @@ def save(self, target_dir: str) -> None: storage[i] = self[i][0] # Drop metadata. self._datasets = [storage] self._indices = {i: (0, i) for i in range(len(self))} - self.set_transforms(*transforms) + self._transform, self._target_transform = transforms def load(self, source_dir: str) -> None: if not len(self): From 25c9144a319a14db83a651f0bf5941c1879460d6 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 31 May 2023 10:50:58 +0200 Subject: [PATCH 19/89] Missing dependency and problem with import (#272) --- requirements.txt | 1 + src/renate/benchmark/datasets/wild_time_data.py | 8 ++++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6da017dc..f7c22ed9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,7 @@ syne-tune[aws,gpsearchers]==0.6.0 pytorch-lightning>=1.8.0, <1.9.5 Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 +tensorboardX>=2.5.0, <2.5.2 torchmetrics~=0.10.3 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 diff --git a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py index d1727c5f..2655ae85 100644 --- a/src/renate/benchmark/datasets/wild_time_data.py +++ b/src/renate/benchmark/datasets/wild_time_data.py @@ -4,8 +4,6 @@ from typing import Any, Dict, Optional, Union from transformers import PreTrainedTokenizer -from wild_time_data import load_dataset -from wild_time_data.core import available_time_steps, dataset_classes from renate import defaults from renate.data.data_module import RenateDataModule @@ -63,6 +61,8 @@ def prepare_data(self) -> None: If s3 bucket is given, the data is downloaded from s3, otherwise from the original source. 
""" if self._src_bucket is None: + from wild_time_data import available_time_steps, load_dataset + load_dataset( dataset_name=self._dataset_name, time_step=available_time_steps(self._dataset_name)[0], @@ -70,6 +70,8 @@ def prepare_data(self) -> None: data_dir=self._data_path, ) else: + from wild_time_data.core import dataset_classes + dst_dir = Path(self._data_path) / dataset_classes[self._dataset_name].file_name if not dst_dir.exists(): download_folder_from_s3( @@ -80,6 +82,8 @@ def prepare_data(self) -> None: def setup(self) -> None: """Set up train, test and val datasets.""" + from wild_time_data import available_time_steps, load_dataset + kwargs = { "dataset_name": self._dataset_name, "time_step": available_time_steps(self._dataset_name)[self.time_step], From 34ba409a7b2e463f5109cdd5193ddbbaedcd9ed5 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 31 May 2023 17:03:20 +0200 Subject: [PATCH 20/89] Fix Offline-ER bug and change loss functions (#273) --- doc/getting_started/how_to_renate_config.rst | 3 +- examples/getting_started/renate_config.py | 2 +- examples/nlp_finetuning/renate_config.py | 2 +- .../renate_config.py | 4 +++ examples/train_mlp_locally/renate_config.py | 4 +++ src/renate/benchmark/experiment_config.py | 6 ++-- src/renate/benchmark/scenarios.py | 2 +- src/renate/cli/run_training.py | 1 + src/renate/updaters/experimental/er.py | 2 +- .../updaters/experimental/offline_er.py | 30 ++++++++++++++----- src/renate/updaters/learner.py | 2 +- src/renate/utils/module.py | 18 +++++++++-- test/conftest.py | 28 ++++++++--------- .../configs/suites/quick/offline-er.json | 4 +-- .../benchmark/test_experimentation_config.py | 6 ++++ test/renate/models/test_renate_module.py | 13 ++++---- test/renate/renate_config_files/config.py | 6 ++-- .../renate_config_files/config_scenario.py | 6 ++-- .../avalanche/test_avalanche_model_updater.py | 2 +- .../updaters/experimental/test_joint.py | 5 +++- 20 files changed, 99 insertions(+), 47 deletions(-) diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index f5f0915a..6dbb692a 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -59,12 +59,13 @@ signature def loss_fn() -> torch.nn.Module: -An example of this for the task of MNIST classfication above as +An example of this for the task of MNIST classification above as .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Loss function example :lines: 95-96 +Please note, loss functions should not be reduced. 
Data Preparation ================ diff --git a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index 1e931819..756fc729 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -93,4 +93,4 @@ def metrics_fn() -> Dict: def loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") diff --git a/examples/nlp_finetuning/renate_config.py b/examples/nlp_finetuning/renate_config.py index bcb40a3b..5b1b4e79 100644 --- a/examples/nlp_finetuning/renate_config.py +++ b/examples/nlp_finetuning/renate_config.py @@ -25,7 +25,7 @@ def model_fn(model_state_url: Optional[str] = None) -> RenateModule: def loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> RenateDataModule: diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 563d46e7..49567c9d 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -61,3 +61,7 @@ def test_transform() -> Callable: def buffer_transform() -> Callable: """Returns a transform function to be used in the Memory Buffer.""" return train_transform() + + +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss(reduction="none") diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index 779f4ef3..772577d3 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -48,3 +48,7 @@ def model_fn(model_state_url: Optional[str] = None) -> RenateModule: def train_transform() -> Callable: """Returns a transform function to be used in the training.""" return transforms.Lambda(lambda x: torch.flatten(x)) + + +def loss_fn() -> torch.nn.Module: + return torch.nn.CrossEntropyLoss(reduction="none") diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 1dbb5b11..0a185d12 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -231,8 +231,10 @@ def get_scenario( raise ValueError(f"Unknown scenario `{scenario_name}`.") -def loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() +def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: + if updater.startswith("Avalanche-"): + return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") def data_module_fn( diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index 7e45c061..b90031fd 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -117,7 +117,7 @@ def __init__( self, data_module: RenateDataModule, chunk_id: int, - class_groupings: List[List[int]], + class_groupings: Tuple[Tuple[int]], ) -> None: super().__init__(data_module, len(class_groupings), chunk_id) self._class_groupings = class_groupings diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index 04e794e6..56e38756 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -122,6 +122,7 @@ def run(self): ) loss_fn = get_loss_fn( config_module, + not args.updater.startswith("Avalanche-"), **get_function_kwargs(args=args, function_args=function_args["loss_fn"]), ) diff --git 
a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index 99375860..7dd279a0 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -135,7 +135,7 @@ def training_step( outputs_memory=outputs_memory, batch_memory=batch_memory, intermediate_representation_memory=intermediate_representation_memory, - ) + ).mean() self._loss_collections["train_losses"][name](component_loss) step_output["loss"] += component_loss loss_normalization += component.weight diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index dc2779e1..2c505b01 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -14,6 +14,7 @@ from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater +from renate.utils.pytorch import move_tensors_to_device class OfflineExperienceReplayLearner(ReplayLearner): @@ -96,16 +97,31 @@ def training_step( else: alpha = self._loss_weight_new_data inputs, targets = batch["current_task"] + device = inputs.device + batch_size_current = inputs.shape[0] + batch_size_mem = 0 + if "memory" in batch: + (inputs_mem, targets_mem), _ = batch["memory"] + batch_size_mem = inputs_mem.shape[0] + inputs = torch.cat((inputs, inputs_mem), 0) + targets = torch.cat((targets, targets_mem), 0) outputs = self(inputs) loss = self._loss_fn(outputs, targets) - self._loss_collections["train_losses"]["base_loss"](loss) - self._update_metrics(outputs, targets, "train") if "memory" in batch: - (inputs_mem, targets_mem), _ = batch["memory"] - outputs_mem = self(inputs_mem) - loss_mem = self._loss_fn(outputs_mem, targets_mem) - self._loss_collections["train_losses"]["memory_loss"](loss_mem) - loss = alpha * loss + (1.0 - alpha) * loss_mem + weights = torch.Tensor( + [ + [alpha for _ in range(batch_size_current)] + + [(1 - alpha) for _ in range(batch_size_mem)] + ] + ) + self._loss_collections["train_losses"]["memory_loss"](loss[batch_size_current:].mean()) + self._loss_collections["train_losses"]["base_loss"](loss[:batch_size_current].mean()) + weights = move_tensors_to_device(weights, device=device) + loss = weights / weights.mean() * loss + else: + self._loss_collections["train_losses"]["base_loss"](loss[:batch_size_current].mean()) + loss = loss.mean() + self._update_metrics(outputs, targets, "train") return {"loss": loss} def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index 7a1c346b..3a5211fb 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -238,7 +238,7 @@ def training_step( outputs = self(inputs) intermediate_representation = self._model.get_intermediate_representation() self._model.reset_intermediate_representation_cache() - loss = self._loss_fn(outputs, targets) + loss = self._loss_fn(outputs, targets).mean() self._update_metrics(outputs, targets, "train") self._loss_collections["train_losses"]["base_loss"](loss) return { diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index a45766bc..2434c680 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import importlib.util import sys +import warnings from types import ModuleType from typing import Any, Callable, Dict, List, Optional, Union @@ -81,9 +82,22 @@ def 
get_data_module(config_module: ModuleType, **kwargs: Any) -> RenateDataModul return getattr(config_module, "data_module_fn")(**kwargs) -def get_loss_fn(config_module: ModuleType, **kwargs: Any) -> torch.nn.Module: +def _convert_loss(loss_fn: torch.nn.Module): + """Changes PyTorch loss such that it uses no reduction.""" + if hasattr(loss_fn, "reduction") and loss_fn.reduction != "none": + warnings.warn( + "Renate assumes that your loss function returns a loss value for each data point." + f"Your loss function uses reduction={loss_fn.reduction}, changing to `none`." + ) + loss_fn.reduction = "none" + + +def get_loss_fn(config_module: ModuleType, convert: bool, **kwargs: Any) -> torch.nn.Module: """Creates and returns the loss function from config""" - return getattr(config_module, "loss_fn")(**kwargs) + loss_fn = getattr(config_module, "loss_fn")(**kwargs) + if convert: + _convert_loss(loss_fn) + return loss_fn def get_metrics(config_module: ModuleType) -> Dict[str, torchmetrics.Metric]: diff --git a/test/conftest.py b/test/conftest.py index ad64b255..659aa990 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -76,7 +76,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, Learner: { "optimizer": "SGD", @@ -85,7 +85,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.005, "batch_size": 10, "seed": 42, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, GDumbLearner: { "optimizer": "SGD", @@ -95,7 +95,7 @@ def pytest_collection_modifyitems(config, items): "batch_size": 10, "seed": 42, "memory_size": 30, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, JointLearner: { "optimizer": "SGD", @@ -104,7 +104,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.001, "batch_size": 10, "seed": 3, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, RepeatedDistillationLearner: { "optimizer": "SGD", @@ -114,7 +114,7 @@ def pytest_collection_modifyitems(config, items): "batch_size": 10, "seed": 42, "memory_size": 30, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, OfflineExperienceReplayLearner: { "memory_size": 30, @@ -126,7 +126,7 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, } AVALANCHE_LEARNER_KWARGS = { @@ -181,14 +181,14 @@ def pytest_collection_modifyitems(config, items): "momentum": 0.5, "weight_decay": 0.01, "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, Learner: { "optimizer": "Adam", "learning_rate": 3.0, "weight_decay": 0.01, "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, GDumbLearner: { "optimizer": "Adam", @@ -197,21 +197,21 @@ def pytest_collection_modifyitems(config, items): "weight_decay": 0.03, "batch_size": 128, "memory_size": 50, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, JointLearner: { "optimizer": "Adam", "learning_rate": 2.0, "weight_decay": 0.01, "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(), + 
"loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, RepeatedDistillationLearner: { "optimizer": "Adam", "learning_rate": 2.0, "weight_decay": 0.01, "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, OfflineExperienceReplayLearner: { "optimizer": "Adam", @@ -219,7 +219,7 @@ def pytest_collection_modifyitems(config, items): "momentum": 0.5, "weight_decay": 0.01, "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(), + "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, } AVALANCHE_LEARNER_HYPERPARAMETER_UPDATES = { @@ -271,8 +271,8 @@ def get_renate_module_mlp( @pytest.helpers.register -def get_loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() +def get_loss_fn(reduction="none") -> torch.nn.Module: + return torch.nn.CrossEntropyLoss(reduction=reduction) @pytest.helpers.register diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 6673d49d..c0b04935 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.7039999961853027, 0.4569999873638153], [0.6664999723434448, 0.5544999837875366]], - "expected_accuracy_darwin": [[0.6965000033378601, 0.4284999966621399]] + "expected_accuracy_linux": [[0.765500009059906, 0.4020000100135803], [0.7574999928474426, 0.4424999952316284]], + "expected_accuracy_darwin": [[0.7549999952316284, 0.45249998569488525]] } diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index d07f2af3..0bb06f2f 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -10,6 +10,7 @@ data_module_fn, get_data_module, get_scenario, + loss_fn, model_fn, models, train_transform, @@ -306,3 +307,8 @@ def test_prediction_strategy_is_correctly_set(model_name, updater): assert not hasattr(model, "_prediction_strategy") or model._prediction_strategy is None else: assert isinstance(model._prediction_strategy, ICaRLClassificationStrategy) + + +def test_loss_fn_returns_correct_reduction_type(): + assert loss_fn("ER").reduction == "none" + assert loss_fn("Avalanche-ER").reduction == "mean" diff --git a/test/renate/models/test_renate_module.py b/test/renate/models/test_renate_module.py index 318265eb..46b6edd7 100644 --- a/test/renate/models/test_renate_module.py +++ b/test/renate/models/test_renate_module.py @@ -17,7 +17,7 @@ def test_failing_to_init_abs_class(): with pytest.raises(TypeError): - RenateModule({"toy_hyperparam": 1.0}, torch.nn.CrossEntropyLoss()) + RenateModule({"toy_hyperparam": 1.0}, pytest.helpers.get_loss_fn()) @pytest.mark.parametrize( @@ -43,7 +43,7 @@ def test_renate_model_save(tmpdir, model): @pytest.mark.parametrize( - "test_case,test_cls,loss_fn", + "test_case,test_cls", [ [ pytest.helpers.get_renate_module_mlp_and_data( @@ -55,7 +55,6 @@ def test_renate_model_save(tmpdir, model): test_num_samples=5, ), MultiLayerPerceptron, - torch.nn.CrossEntropyLoss, ], [ pytest.helpers.get_renate_vision_module_and_data( @@ -67,7 +66,6 @@ def test_renate_model_save(tmpdir, model): test_num_samples=5, ), ResNet, - torch.nn.CrossEntropyLoss, ], [ pytest.helpers.get_renate_vision_module_and_data( @@ -79,17 +77,16 @@ def 
test_renate_model_save(tmpdir, model): test_num_samples=5, ), VisionTransformer, - torch.nn.CrossEntropyLoss, ], ], ) -def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls, loss_fn): +def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls): model, _, test_data = test_case model.eval() y = torch.randint(3, 8, (5,)) y_hat_pre_save = model(test_data) - loss_pre_save = loss_fn()(y_hat_pre_save, y) + loss_pre_save = pytest.helpers.get_loss_fn("mean")(y_hat_pre_save, y) torch.save(model.state_dict(), os.path.join(tmpdir, "test_model.pt")) state = torch.load(os.path.join(tmpdir, "test_model.pt")) @@ -98,7 +95,7 @@ def test_renate_model_singlehead_save_and_load(tmpdir, test_case, test_cls, loss model2 = test_cls.from_state_dict(state) model2.eval() y_hat_post_load = model2(test_data) - loss_post_load = loss_fn()(y_hat_post_load, y) + loss_post_load = pytest.helpers.get_loss_fn("mean")(y_hat_post_load, y) assert torch.allclose(y_hat_pre_save, y_hat_post_load) assert torch.allclose(loss_post_load, loss_pre_save) diff --git a/test/renate/renate_config_files/config.py b/test/renate/renate_config_files/config.py index 3bfdeb5f..14b4a435 100644 --- a/test/renate/renate_config_files/config.py +++ b/test/renate/renate_config_files/config.py @@ -30,5 +30,7 @@ def data_module_fn( return DummyTorchVisionDataModule(transform=None, val_size=val_size, seed=seed) -def loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() +def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: + if updater.startswith("Avalanche-"): + return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") diff --git a/test/renate/renate_config_files/config_scenario.py b/test/renate/renate_config_files/config_scenario.py index 70d81c6f..cad1efd1 100644 --- a/test/renate/renate_config_files/config_scenario.py +++ b/test/renate/renate_config_files/config_scenario.py @@ -33,5 +33,7 @@ def data_module_fn( ) -def loss_fn() -> torch.nn.Module: - return torch.nn.CrossEntropyLoss() +def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: + if updater.startswith("Avalanche-"): + return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") diff --git a/test/renate/updaters/avalanche/test_avalanche_model_updater.py b/test/renate/updaters/avalanche/test_avalanche_model_updater.py index e0df1d89..000d4bda 100644 --- a/test/renate/updaters/avalanche/test_avalanche_model_updater.py +++ b/test/renate/updaters/avalanche/test_avalanche_model_updater.py @@ -87,7 +87,7 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b "memory_size": memory_size, "memory_batch_size": memory_batch_size, "batch_size": batch_size, - "loss_fn": pytest.helpers.get_loss_fn(), + "loss_fn": pytest.helpers.get_loss_fn("mean"), } model_updater = ExperienceReplayAvalancheModelUpdater( output_state_folder=Path(tmpdir) / "0", diff --git a/test/renate/updaters/experimental/test_joint.py b/test/renate/updaters/experimental/test_joint.py index 597251c5..c31e69c9 100644 --- a/test/renate/updaters/experimental/test_joint.py +++ b/test/renate/updaters/experimental/test_joint.py @@ -42,7 +42,10 @@ def test_joint_learner_model_reset(): model_updater = pytest.helpers.get_simple_updater( model=model, learner_class=JointLearner, - learner_kwargs={"learning_rate": 0.0, "loss_fn": torch.nn.CrossEntropyLoss()}, + learner_kwargs={ + "learning_rate": 0.0, + "loss_fn": pytest.helpers.get_loss_fn(), + }, max_epochs=1, ) model = 
model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) From 6e29acd5acd1e4e675373cee8326eb4c6301968b Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 1 Jun 2023 14:28:17 +0200 Subject: [PATCH 21/89] Longer Experiments for GPUs (#246) --- .../scenarios/class-incremental-5updates.json | 7 +++ .../scenarios/feature-sorting-5updates.json | 8 +++ .../configs/scenarios/hue-shift-5updates.json | 8 +++ .../configs/scenarios/iid-5updates.json | 6 ++ .../scenarios/permutation-10updates.json | 7 +++ .../configs/scenarios/rotation-10updates.json | 3 +- .../configs/suites/main/avalanche-er.json | 9 +++ .../configs/suites/main/avalanche-ewc.json | 9 +++ .../configs/suites/main/avalanche-icarl.json | 9 +++ .../configs/suites/main/avalanche-lwf.json | 9 +++ .../configs/suites/main/cls-er.json | 9 +++ .../configs/suites/main/der.json | 9 +++ .../configs/suites/main/er.json | 9 +++ .../configs/suites/main/gdumb.json | 9 +++ .../configs/suites/main/joint.json | 9 +++ .../configs/suites/main/offline-er.json | 10 ++++ .../configs/suites/main/pod-er.json | 9 +++ .../configs/suites/main/rd.json | 9 +++ .../configs/suites/main/super-er.json | 9 +++ .../configs/updaters/rd.json | 9 +++ test/integration_tests/run_experiment.py | 8 ++- test/integration_tests/run_quick_test.py | 1 + test/integration_tests/run_test.py | 55 +++++++++++++++++++ 23 files changed, 227 insertions(+), 3 deletions(-) create mode 100644 test/integration_tests/configs/scenarios/class-incremental-5updates.json create mode 100644 test/integration_tests/configs/scenarios/feature-sorting-5updates.json create mode 100644 test/integration_tests/configs/scenarios/hue-shift-5updates.json create mode 100644 test/integration_tests/configs/scenarios/iid-5updates.json create mode 100644 test/integration_tests/configs/scenarios/permutation-10updates.json create mode 100644 test/integration_tests/configs/suites/main/avalanche-er.json create mode 100644 test/integration_tests/configs/suites/main/avalanche-ewc.json create mode 100644 test/integration_tests/configs/suites/main/avalanche-icarl.json create mode 100644 test/integration_tests/configs/suites/main/avalanche-lwf.json create mode 100644 test/integration_tests/configs/suites/main/cls-er.json create mode 100644 test/integration_tests/configs/suites/main/der.json create mode 100644 test/integration_tests/configs/suites/main/er.json create mode 100644 test/integration_tests/configs/suites/main/gdumb.json create mode 100644 test/integration_tests/configs/suites/main/joint.json create mode 100644 test/integration_tests/configs/suites/main/offline-er.json create mode 100644 test/integration_tests/configs/suites/main/pod-er.json create mode 100644 test/integration_tests/configs/suites/main/rd.json create mode 100644 test/integration_tests/configs/suites/main/super-er.json create mode 100644 test/integration_tests/configs/updaters/rd.json create mode 100644 test/integration_tests/run_test.py diff --git a/test/integration_tests/configs/scenarios/class-incremental-5updates.json b/test/integration_tests/configs/scenarios/class-incremental-5updates.json new file mode 100644 index 00000000..cf1879ab --- /dev/null +++ b/test/integration_tests/configs/scenarios/class-incremental-5updates.json @@ -0,0 +1,7 @@ +{ + "val_size": 0.05, + "scenario_name": "ClassIncrementalScenario", + "class_groupings": [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + "num_tasks": 5, + "max_epochs": 50 +} diff --git a/test/integration_tests/configs/scenarios/feature-sorting-5updates.json 
b/test/integration_tests/configs/scenarios/feature-sorting-5updates.json new file mode 100644 index 00000000..401fdbf5 --- /dev/null +++ b/test/integration_tests/configs/scenarios/feature-sorting-5updates.json @@ -0,0 +1,8 @@ +{ + "val_size": 0.05, + "scenario_name": "FeatureSortingScenario", + "feature_idx": 0, + "randomness": 0, + "num_tasks": 5, + "max_epochs": 50 +} diff --git a/test/integration_tests/configs/scenarios/hue-shift-5updates.json b/test/integration_tests/configs/scenarios/hue-shift-5updates.json new file mode 100644 index 00000000..216da549 --- /dev/null +++ b/test/integration_tests/configs/scenarios/hue-shift-5updates.json @@ -0,0 +1,8 @@ +{ + "val_size": 0.05, + "scenario_name": "HueShiftScenario", + "feature_idx": 0, + "randomness": 0, + "num_tasks": 5, + "max_epochs": 50 +} diff --git a/test/integration_tests/configs/scenarios/iid-5updates.json b/test/integration_tests/configs/scenarios/iid-5updates.json new file mode 100644 index 00000000..1820b8a9 --- /dev/null +++ b/test/integration_tests/configs/scenarios/iid-5updates.json @@ -0,0 +1,6 @@ +{ + "val_size": 0.05, + "scenario_name": "IIDScenario", + "num_tasks": 5, + "max_epochs": 50 +} diff --git a/test/integration_tests/configs/scenarios/permutation-10updates.json b/test/integration_tests/configs/scenarios/permutation-10updates.json new file mode 100644 index 00000000..86ff4cb5 --- /dev/null +++ b/test/integration_tests/configs/scenarios/permutation-10updates.json @@ -0,0 +1,7 @@ +{ + "val_size": 0.05, + "input_dim": [784], + "scenario_name": "PermutationScenario", + "num_tasks": 10, + "max_epochs": 15 +} diff --git a/test/integration_tests/configs/scenarios/rotation-10updates.json b/test/integration_tests/configs/scenarios/rotation-10updates.json index 33f3f41d..1865e0e5 100644 --- a/test/integration_tests/configs/scenarios/rotation-10updates.json +++ b/test/integration_tests/configs/scenarios/rotation-10updates.json @@ -2,5 +2,6 @@ "val_size": 0.05, "scenario_name": "ImageRotationScenario", "degrees": [0, 36, 72, 108, 144, 180, 216, 252, 288, 324], - "num_tasks": 10 + "num_tasks": 10, + "max_epochs": 15 } diff --git a/test/integration_tests/configs/suites/main/avalanche-er.json b/test/integration_tests/configs/suites/main/avalanche-er.json new file mode 100644 index 00000000..a7211272 --- /dev/null +++ b/test/integration_tests/configs/suites/main/avalanche-er.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "avalanche-er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-avalanche-er", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/avalanche-ewc.json b/test/integration_tests/configs/suites/main/avalanche-ewc.json new file mode 100644 index 00000000..384a4e0d --- /dev/null +++ b/test/integration_tests/configs/suites/main/avalanche-ewc.json @@ -0,0 +1,9 @@ +{ + "scenario": "rotation-10updates.json", + "model": "mlp-200.json", + "updater": "avalanche-ewc.json", + "dataset": "mnist.json", + "backend": "sagemaker", + "job_name": "rotation-mlp-avalanche-ewc", + "expected_accuracy": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/avalanche-icarl.json b/test/integration_tests/configs/suites/main/avalanche-icarl.json new file mode 100644 index 00000000..abd376b2 --- /dev/null +++ b/test/integration_tests/configs/suites/main/avalanche-icarl.json @@ -0,0 +1,9 @@ +{ + "scenario": 
"class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "avalanche-icarl-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-avalanche-icarl", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/avalanche-lwf.json b/test/integration_tests/configs/suites/main/avalanche-lwf.json new file mode 100644 index 00000000..c326acc4 --- /dev/null +++ b/test/integration_tests/configs/suites/main/avalanche-lwf.json @@ -0,0 +1,9 @@ +{ + "scenario": "permutation-10updates.json", + "model": "mlp-200.json", + "updater": "avalanche-lwf.json", + "dataset": "fashionmnist.json", + "backend": "sagemaker", + "job_name": "permutation-mlp-avalanche-lwf", + "expected_accuracy": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/cls-er.json b/test/integration_tests/configs/suites/main/cls-er.json new file mode 100644 index 00000000..412ae1ff --- /dev/null +++ b/test/integration_tests/configs/suites/main/cls-er.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "cls-er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-cls-er", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/der.json b/test/integration_tests/configs/suites/main/der.json new file mode 100644 index 00000000..0e6bd5c3 --- /dev/null +++ b/test/integration_tests/configs/suites/main/der.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "der-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-der", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/er.json b/test/integration_tests/configs/suites/main/er.json new file mode 100644 index 00000000..e6c024b8 --- /dev/null +++ b/test/integration_tests/configs/suites/main/er.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-er", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/gdumb.json b/test/integration_tests/configs/suites/main/gdumb.json new file mode 100644 index 00000000..04dffa2a --- /dev/null +++ b/test/integration_tests/configs/suites/main/gdumb.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "gdumb-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-gdumb", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/joint.json b/test/integration_tests/configs/suites/main/joint.json new file mode 100644 index 00000000..2ce95edf --- /dev/null +++ b/test/integration_tests/configs/suites/main/joint.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "joint.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-joint", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/offline-er.json 
b/test/integration_tests/configs/suites/main/offline-er.json new file mode 100644 index 00000000..9d305933 --- /dev/null +++ b/test/integration_tests/configs/suites/main/offline-er.json @@ -0,0 +1,10 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "offline-er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-offline-er", + "loss_weight_new_data": 0.5, + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/pod-er.json b/test/integration_tests/configs/suites/main/pod-er.json new file mode 100644 index 00000000..63b26c82 --- /dev/null +++ b/test/integration_tests/configs/suites/main/pod-er.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "pod-er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-pod-er", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/rd.json b/test/integration_tests/configs/suites/main/rd.json new file mode 100644 index 00000000..cf3c92e8 --- /dev/null +++ b/test/integration_tests/configs/suites/main/rd.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "rd.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-rd", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/suites/main/super-er.json b/test/integration_tests/configs/suites/main/super-er.json new file mode 100644 index 00000000..443ee8a6 --- /dev/null +++ b/test/integration_tests/configs/suites/main/super-er.json @@ -0,0 +1,9 @@ +{ + "scenario": "class-incremental-5updates.json", + "model": "resnet18-cifar.json", + "updater": "super-er-buffer500.json", + "dataset": "cifar10.json", + "backend": "sagemaker", + "job_name": "class-incremental-resnet18-super-er", + "expected_accuracy": [0, 0, 0, 0, 0] +} diff --git a/test/integration_tests/configs/updaters/rd.json b/test/integration_tests/configs/updaters/rd.json new file mode 100644 index 00000000..4f55f984 --- /dev/null +++ b/test/integration_tests/configs/updaters/rd.json @@ -0,0 +1,9 @@ +{ + "updater": "RD", + "optimizer": "Adam", + "learning_rate": 0.01, + "momentum": 0.0, + "weight_decay": 0.0, + "batch_size": 256, + "memory_size": 500 +} diff --git a/test/integration_tests/run_experiment.py b/test/integration_tests/run_experiment.py index a4f4623a..c249171d 100644 --- a/test/integration_tests/run_experiment.py +++ b/test/integration_tests/run_experiment.py @@ -65,14 +65,18 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): ) if args.backend == "local": experiment_outputs_url = ( - Path("tmp") / "renate-integration-tests" / args.test_suite / args.job_name + Path("tmp") + / "renate-integration-tests" + / args.test_suite + / args.job_name + / str(args.seed) ) role = None else: AWS_ACCOUNT_ID = boto3.client("sts").get_caller_identity().get("Account") experiment_outputs_url = ( f"s3://sagemaker-us-west-2-{AWS_ACCOUNT_ID}/renate-integration-tests/" - f"{args.test_suite}/{args.job_name}/" + f"{args.test_suite}/{args.job_name}/{args.seed}" ) role = get_execution_role() execute_experiment_job( diff --git a/test/integration_tests/run_quick_test.py b/test/integration_tests/run_quick_test.py index 44029384..168f149f 100644 --- 
a/test/integration_tests/run_quick_test.py +++ b/test/integration_tests/run_quick_test.py @@ -58,6 +58,7 @@ / "renate-integration-tests" / test_suite / job_name + / "0" / "logs" / f"metrics_summary_update_{num_updates - 1}.csv" ) diff --git a/test/integration_tests/run_test.py b/test/integration_tests/run_test.py new file mode 100644 index 00000000..8bcc610f --- /dev/null +++ b/test/integration_tests/run_test.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import argparse +import json +import os +import subprocess +from pathlib import Path + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--test-file", + type=str, + required=True, + help="Test suite to run.", + ) + parser.add_argument( + "--seed", + type=int, + required=True, + help="Seed.", + ) + args = parser.parse_args() + test_suite = "main" + current_folder = Path(os.path.dirname(__file__)) + configs_folder = current_folder / "configs" + test_file = configs_folder / "suites" / test_suite / args.test_file + if not test_file.is_file(): + raise FileNotFoundError(f"Unknown test file '{test_file}'.") + with open(test_file) as f: + test_config = json.load(f) + job_name = f"{test_config['job_name']}" + process = subprocess.Popen( + [ + "python", + current_folder / "run_experiment.py", + "--scenario-file", + configs_folder / "scenarios" / test_config["scenario"], + "--model-file", + configs_folder / "models" / test_config["model"], + "--updater-file", + configs_folder / "updaters" / test_config["updater"], + "--dataset-file", + configs_folder / "datasets" / test_config["dataset"], + "--backend", + test_config["backend"], + "--job-name", + job_name, + "--test-suite", + test_suite, + "--seed", + str(args.seed), + ] + ) + process.wait() From 512a04c1c41ec72276c6e0219955ccc8f40e8f3d Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Thu, 1 Jun 2023 15:37:22 +0200 Subject: [PATCH 22/89] Add example of using Renate in your own script (#274) --- doc/examples/custom_training_script.rst | 44 ++++++++++++++++++++++++ doc/examples/index.rst | 3 +- examples/custom_training_script/train.py | 41 ++++++++++++++++++++++ 3 files changed, 87 insertions(+), 1 deletion(-) create mode 100644 doc/examples/custom_training_script.rst create mode 100644 examples/custom_training_script/train.py diff --git a/doc/examples/custom_training_script.rst b/doc/examples/custom_training_script.rst new file mode 100644 index 00000000..5f756cd6 --- /dev/null +++ b/doc/examples/custom_training_script.rst @@ -0,0 +1,44 @@ +Using Renate in a Custom Training Script +**************************************** + +Usually, we use Renate by writing a :code:`renate_config.py` and launching training jobs via the +:py:func:`~renate.training.training.run_training_job` function. In this example, we demonstrate how +to write your own training script and use Renate in a functional way. This can be useful, e.g., for +debugging new components. + +Here, we use Renate to fine-tune a pretrained Transformer model on a sequence classification +dataset. First, we create the model and a loss function. Since this is a static model, we simply +wrap it using the :py:class:`~renate.models.renate_module.RenateWrapper` class. Recall that loss +functions should produce one loss value per input example (:code:`reduction="none"` for PyTorch's +built-in losses), as explained in :ref:`getting_started/how_to_renate_config:Loss Definition`. + +..
literalinclude:: ../../examples/custom_training_script/train.py + :lines: 12-16 + +Next, we prepare the dataset on which we want to fine-tune the model. Here, we use the +:py:class:`~renate.benchmark.datasets.nlp_datasets.HuggingFaceTextDataModule` to load the +:code:`"rotten_tomatoes"` dataset from the Hugging Face hub. This will also take care of tokenization for us, +if we pass it the corresponding tokenizer. + +.. literalinclude:: ../../examples/custom_training_script/train.py + :lines: 19-24 + +Now we can instantiate a :py:class:`~renate.updaters.model_updater.ModelUpdater` to perform the +training. Since we just want to fine-tune the model on a single dataset here, we use the +:py:class:`~renate.updaters.experimental.fine_tuning.FineTuningModelUpdater`. We pass our +:code:`model` as well as training details, such as the optimizer to use and its hyperparameters. +The model updater also receives all options related to distributed training, as explained in +:ref:`examples/nlp_finetuning:Support for training large models`. +Once the model updater is created, we initiate the training by calling its +:py:meth:`~renate.updaters.model_updater.ModelUpdater.update` method and passing training and +(optionally) validation datasets. + +.. literalinclude:: ../../examples/custom_training_script/train.py + :lines: 27-37 + +Once the training is terminated, your model is ready to deploy. Here, we just save its weights for +later use using standard +`PyTorch functionality `_. + +.. literalinclude:: ../../examples/custom_training_script/train.py + :lines: 41- diff --git a/doc/examples/index.rst b/doc/examples/index.rst index ec04ca4d..1d626ccf 100644 --- a/doc/examples/index.rst +++ b/doc/examples/index.rst @@ -7,4 +7,5 @@ Examples train_mlp_locally train_classifier_sagemaker - nlp_finetuning \ No newline at end of file + nlp_finetuning + custom_training_script \ No newline at end of file diff --git a/examples/custom_training_script/train.py b/examples/custom_training_script/train.py new file mode 100644 index 00000000..026b67e5 --- /dev/null +++ b/examples/custom_training_script/train.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import torch +import transformers + +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule +from renate.models.renate_module import RenateWrapper +from renate.updaters.experimental.fine_tuning import FineTuningModelUpdater + + +### Create model. +transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained( + "distilbert-base-uncased", num_labels=2, return_dict=False +) +model = RenateWrapper(transformer_model) +loss_fn = torch.nn.CrossEntropyLoss(reduction="none") + +### Prepare data. +tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased") +data_module = HuggingFaceTextDataModule( + "data", dataset_name="rotten_tomatoes", tokenizer=tokenizer, val_size=0.2 +) +data_module.prepare_data() # For multi-GPU, call only on rank 0. +data_module.setup() + +### Instantiate renate ModelUpdater and run fine-tuning. +updater = FineTuningModelUpdater( + model, + loss_fn, + optimizer="Adam", + learning_rate=3e-4, + batch_size=32, + max_epochs=3, + input_state_folder=None, + output_state_folder="renate_output", +) +updater.update(data_module.train_data(), data_module.val_data()) + + +### Do something with model, e.g., save its weights for later use.
+torch.save(model.state_dict(), "model_weights.pt") From e7cade80b35d03095f14839b20706a5804111ab3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:36:06 +0200 Subject: [PATCH 23/89] Bump pytest from 7.2.2 to 7.3.1 (#259) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 33cd2f27..5f41356a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dev = [ "wild-time-data==0.1.0", "torch>=1.10.0, <1.12.2", # PyTest Dependencies - "pytest==7.2.2", + "pytest==7.3.1", "pytest-cov==4.0.0", "pytest-helpers-namespace==2021.12.29", ] From ac1ba3deaee5775d191c0696ebff45a1c1610e20 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:38:17 +0200 Subject: [PATCH 24/89] Bump black from 23.1.0 to 23.3.0 (#255) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5f41356a..1940f76b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ benchmark = [ "wild-time-data==0.1.0", ] dev = [ - "black==23.1.0", + "black==23.3.0", "avalanche_lib==0.3.1", "wild-time-data==0.1.0", "torch>=1.10.0, <1.12.2", From f1ba74e6313ed7bc8cb58fc6b3dd8950550e7113 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:39:40 +0200 Subject: [PATCH 25/89] Bump pydata-sphinx-theme from 0.13.1 to 0.13.3 (#183) From 27481f42cafd8da39c3569f77dd0122c2b8cd8ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:41:22 +0200 Subject: [PATCH 26/89] Bump sphinx-copybutton from 0.5.1 to 0.5.2 (#256) --- doc/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 8d7193be..36fd8312 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,6 +1,6 @@ docutils==0.19 Sphinx==6.1.3 -sphinx-copybutton==0.5.1 +sphinx-copybutton==0.5.2 sphinx-hoverxref==1.3.0 sphinxext-opengraph==0.8.1 pydata-sphinx-theme==0.13.3 From fdf407aa6224973f3ca32eb59f02b6fafdd229ee Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 09:43:03 +0200 Subject: [PATCH 27/89] Bump sphinx-autodoc-typehints from 1.22.0 to 1.23.0 (#257) --- doc/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 36fd8312..03512847 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -4,7 +4,7 @@ sphinx-copybutton==0.5.2 sphinx-hoverxref==1.3.0 sphinxext-opengraph==0.8.1 pydata-sphinx-theme==0.13.3 -sphinx-autodoc-typehints==1.22.0 +sphinx-autodoc-typehints==1.23.0 sphinx-paramlinks==0.5.4 # Temporarily added From 863c96c975076e4b32e28ad3b09118fdac25cb88 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 10:02:59 +0200 Subject: [PATCH 28/89] Bump pytest-cov from 4.0.0 to 4.1.0 (#266) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1940f76b..94835147 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dev = [ "torch>=1.10.0, <1.12.2", # PyTest Dependencies "pytest==7.3.1", - "pytest-cov==4.0.0", + "pytest-cov==4.1.0", 
"pytest-helpers-namespace==2021.12.29", ] From 9e3ac39fd991d8cfbe0a6a77ca334d683b16d0f5 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 10:13:53 +0200 Subject: [PATCH 29/89] Bump sphinxext-opengraph from 0.8.1 to 0.8.2 (#253) --- doc/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 03512847..e16ca83b 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -2,7 +2,7 @@ docutils==0.19 Sphinx==6.1.3 sphinx-copybutton==0.5.2 sphinx-hoverxref==1.3.0 -sphinxext-opengraph==0.8.1 +sphinxext-opengraph==0.8.2 pydata-sphinx-theme==0.13.3 sphinx-autodoc-typehints==1.23.0 sphinx-paramlinks==0.5.4 From 37c870989061a181154c491905a97cb2580b442d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 2 Jun 2023 10:15:01 +0200 Subject: [PATCH 30/89] Update pandas requirement from <2.0.2,>=1.4.0 to >=1.4.0,<2.0.3 (#267) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f7c22ed9..5f4ec962 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy>=1.17.2, <1.24.4 torch>=1.10.0, <1.13.2 -pandas>=1.4.0, <2.0.2 +pandas>=1.4.0, <2.0.3 boto3>=1.26.0, <1.26.139 requests>=2.31.0, <2.31.1 sagemaker>=2.112.0, <2.158.1 From e5c23d5640be27470ce4d3a310923300fae36ed8 Mon Sep 17 00:00:00 2001 From: Prabhu Teja S Date: Fri, 2 Jun 2023 18:14:30 +0200 Subject: [PATCH 31/89] Using `num_gpus_per_trial` after SyneTune update (#278) Co-authored-by: Prabhu Teja --- src/renate/training/training.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 3f08a5a4..28011ae3 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -133,7 +133,7 @@ def run_training_job( scheduler_kwargs: Only required if custom scheduler is provided. seed: Seed used for ensuring reproducibility. accelerator: Type of accelerator to use. - devices: Number of devices to use. + devices: Number of devices to use per worker (set in n_workers). strategy: Name of the distributed training strategy to use. precision: Type of bit precision to use. deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. @@ -562,9 +562,8 @@ def _execute_training_and_tuning_job_locally( f"Tuning hyperparameters with respect to {metric} ({mode}) for {max_time} seconds on " f"{n_workers} worker(s)." ) - # TODO: After bumping up SyneTune >= 0.6, use the argument `num_gpus_per_trial`. 
- backend = LocalBackend(entry_point=training_script, rotate_gpus=False if devices > 1 else True) + backend = LocalBackend(entry_point=training_script, num_gpus_per_trial=devices) if scheduler is None or not tune_hyperparameters: if scheduler is not None: warnings.warn( From 8d1b4123f884db5f72ba4a9136be73a470d7620b Mon Sep 17 00:00:00 2001 From: Prabhu Teja S Date: Wed, 7 Jun 2023 17:15:45 +0200 Subject: [PATCH 32/89] Implementing a buffer that handles dataset elements of different sizes (#279) --- src/renate/data/datasets.py | 32 ++++++++++- src/renate/memory/buffer.py | 18 ++---- src/renate/memory/storage.py | 86 ++++++++++++++++++++++------ src/renate/updaters/model_updater.py | 4 +- test/renate/data/test_datasets.py | 15 ++++- test/renate/memory/test_storage.py | 79 +++++++++++++++++++------ 6 files changed, 182 insertions(+), 52 deletions(-) diff --git a/src/renate/data/datasets.py b/src/renate/data/datasets.py index d3de8483..79c8e907 100644 --- a/src/renate/data/datasets.py +++ b/src/renate/data/datasets.py @@ -1,7 +1,6 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Optional, Tuple -from typing import List +from typing import Any, Callable, List, Optional, Tuple, Union import torch from PIL import Image @@ -105,6 +104,35 @@ def __getitem__(self, idx: int) -> NestedTensors: return self._get(self._nested_tensors, idx) +class IndexedSubsetDataset(Dataset): + """A dataset wrapper to keep specified indexes of a dataset element. + + Subset is indexing rows of a (tensor-)dataset, whereas IndexedSubset keeps specified columns. + It currently handles Datasets whose elements are tuples. + + Args: + dataset: The dataset to wrap + indexes_to_keep: An list or tuple of indices that are to be retained. + """ + + def __init__(self, dataset: Dataset, indexes_to_keep: Union[List, Tuple, int]) -> None: + self.dataset = dataset + if isinstance(indexes_to_keep, int): + indexes_to_keep = [indexes_to_keep] + self.indexes_to_keep = set(indexes_to_keep) + + def __getitem__(self, index) -> Any: + curr_item = self.dataset[index] + # Special handling if indexes_to_keep is a single int + if len(ret_val := [ci for i, ci in enumerate(curr_item) if i in self.indexes_to_keep]) == 1: + return ret_val[0] + else: + return tuple(ret_val) + + def __len__(self): + return len(self.dataset) + + class _TransformedDataset(Dataset): """A dataset wrapper that applies transformations. diff --git a/src/renate/memory/buffer.py b/src/renate/memory/buffer.py index f20480d9..7b762a27 100644 --- a/src/renate/memory/buffer.py +++ b/src/renate/memory/buffer.py @@ -8,7 +8,8 @@ from torch.utils.data import Dataset from renate import defaults -from renate.memory.storage import MemoryMappedTensorStorage +from renate.data.datasets import IndexedSubsetDataset +from renate.memory.storage import FileTensorStorage from renate.types import NestedTensors from renate.utils.pytorch import get_generator @@ -157,13 +158,8 @@ def save(self, target_dir: str) -> None: transforms = self._transform, self._target_transform self._transform, self._target_transform = None, None - storage = MemoryMappedTensorStorage( - target_dir, - data_point=self._data_point_prototype, - length=len(self), - ) - for i in range(len(self)): - storage[i] = self[i][0] # Drop metadata. 
+ storage = FileTensorStorage(target_dir) + storage.dump_dataset(IndexedSubsetDataset(self, [0])) self._datasets = [storage] self._indices = {i: (0, i) for i in range(len(self))} self._transform, self._target_transform = transforms @@ -172,11 +168,7 @@ def load(self, source_dir: str) -> None: if not len(self): return - storage = MemoryMappedTensorStorage( - source_dir, - data_point=self._data_point_prototype, - length=len(self), - ) + storage = FileTensorStorage(source_dir) self._datasets = [storage] def _add_metadata_like(self, metadata: DataDict): diff --git a/src/renate/memory/storage.py b/src/renate/memory/storage.py index 05827d40..42f15819 100644 --- a/src/renate/memory/storage.py +++ b/src/renate/memory/storage.py @@ -2,26 +2,34 @@ # SPDX-License-Identifier: Apache-2.0 import math import os -from typing import Any, Tuple +from typing import Any, Tuple, Union, Optional +from pathlib import Path +from warnings import warn import torch from renate.types import NestedTensors -def mmap_tensor(filename: str, size: Tuple[int, ...], dtype: torch.dtype) -> torch.Tensor: +def mmap_tensor( + filename: str, size: Union[int, Tuple[int, ...]], dtype: torch.dtype +) -> torch.Tensor: """Creates or accesses a memory-mapped tensor.""" - t = torch.from_file(filename, shared=True, size=math.prod(size), dtype=dtype, device="cpu") + t = torch.from_file( + filename, + shared=True, + size=math.prod(size) if isinstance(size, tuple) else size, + dtype=dtype, + device="cpu", + ) return t.view(size) class Storage(torch.utils.data.Dataset): """An abstract class for permanent storage of datasets.""" - def __init__(self, directory: str, data_point: Any, length: int) -> None: + def __init__(self, directory: str) -> None: self._directory = directory - self._data_point = data_point - self._length = length def __len__(self) -> int: return self._length @@ -29,7 +37,10 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> Any: raise NotImplementedError() - def __setitem__(self, idx: int, data_point: Any) -> None: + def dump_dataset(self, ds: torch.utils.data.Dataset) -> None: + raise NotImplementedError() + + def load_dataset(self, directory: Union[str, Path]): raise NotImplementedError() @@ -38,7 +49,7 @@ class MemoryMappedTensorStorage(Storage): This implements storage for `length` data points consisting of nested tensors of fixed types and shapes. `Storage` implements `__len__` and `__getitem__` and therefore can be used as a - torch `Datasets`. To populate the storage, it also implements `__setitem__`. It does _not_ keep + torch `Dataset`. To populate the storage, it also implements `dump_dataset`. It does _not_ keep track which slots have or have not been populated. `Storage` is given a path to a directory, where it creates (or accesses, if they already exist) @@ -50,9 +61,16 @@ class MemoryMappedTensorStorage(Storage): length: Number of items to be stored. """ - def __init__(self, directory: str, data_point: NestedTensors, length: int) -> None: - super().__init__(directory, data_point, length) - self._storage = self._create_mmap_tensors(directory, data_point, length) + def __init__(self, directory: str) -> None: + warn( + f"""{self.__class__.__name__} will be deprecated very soon. Use FileTensorStorage + instead. {self.__class__.__name__} is currently not fully functional, as some of the + necessary parts of the interface have been modified and simplified. 
""", + DeprecationWarning, + stacklevel=2, + ) + super().__init__(directory) + self._storage: Optional[NestedTensors] = None @staticmethod def _create_mmap_tensors(path: str, data_point: NestedTensors, length: int) -> NestedTensors: @@ -63,14 +81,14 @@ def _create_mmap_tensors(path: str, data_point: NestedTensors, length: int) -> N elif isinstance(data_point, tuple): return tuple( MemoryMappedTensorStorage._create_mmap_tensors( - os.path.join(path, f"{i}"), data_point[i], length + os.path.join(path, f"{i}.pt"), data_point[i], length ) for i in range(len(data_point)) ) elif isinstance(data_point, dict): return { key: MemoryMappedTensorStorage._create_mmap_tensors( - os.path.join(path, key), data_point[key], length + os.path.join(path, f"{key}.pt"), data_point[key], length ) for key in data_point } @@ -111,6 +129,42 @@ def _set(storage: NestedTensors, idx: int, data_point: NestedTensors) -> None: else: raise TypeError(f"Expected nested tuple/dict of tensors, found {type(storage)}.") - def __setitem__(self, idx: int, data_point: NestedTensors) -> None: - """Set the item stored at index `idx`.""" - self._set(self._storage, idx, data_point) + def dump_dataset(self, ds): + self._length = len(ds) + self._storage = self._create_mmap_tensors(self._directory, ds[0], self._length) + for idx in range(len(self)): + self._set(self._storage, idx, ds[idx]) + + +class FileTensorStorage(Storage): + """A class implementing permanent storage of nested tensor datasets to disk as pickle files. + + This implements storage for `length` data points consisting of nested tensors of fixed types + and shapes. `Storage` implements `__len__` and `__getitem__` and therefore can be used as a + torch `Dataset`. To populate the storage, it also implements `dump_dataset`. It does _not_ keep + track which slots have or have not been populated. + + `Storage` is given a path to a directory, where it creates (or accesses, if they already exist) + pickle files one for each point in the dataset. + + Args: + directory: Path to a directory. + """ + + def __init__(self, directory: str) -> None: + super().__init__(directory) + + def dump_dataset(self, ds: torch.utils.data.Dataset) -> None: + for i in range(len(ds)): + torch.save(ds[i], self._compose_file_path_from_index(i)) + + def __getitem__(self, idx: int) -> Any: + if not hasattr(self, "_length"): + self.load_dataset(None) + return torch.load(self._compose_file_path_from_index(idx)) + + def load_dataset(self, directory: Union[str, Path]): + self._length = len([x for x in os.listdir(self._directory) if x.endswith(".pt")]) + + def _compose_file_path_from_index(self, idx: int) -> str: + return os.path.join(self._directory, f"{idx}.pt") diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index e086cc7a..821601e6 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -141,7 +141,9 @@ def _load_best_checkpoint_and_save(self, trainer: Trainer, pl_module: LightningM # Finalize model update. pl_module.on_model_update_end() # Save permanently. - pl_module.save(self._output_state_folder) + if trainer.is_global_zero: + # Save the buffer only on rank zero. + pl_module.save(self._output_state_folder) # Overwrite checkpoint. 
self._save_checkpoint(trainer, learner_state_path) diff --git a/test/renate/data/test_datasets.py b/test/renate/data/test_datasets.py index 3f1b574d..95f96a83 100644 --- a/test/renate/data/test_datasets.py +++ b/test/renate/data/test_datasets.py @@ -10,7 +10,7 @@ from torch.utils.data import TensorDataset from renate.data import ImageDataset -from renate.data.datasets import _EnumeratedDataset, _TransformedDataset +from renate.data.datasets import _EnumeratedDataset, _TransformedDataset, IndexedSubsetDataset class MulTransform: @@ -58,3 +58,16 @@ def test_image_dataset(tmpdir): data, target = dataset[0] assert torch.equal(data, torch.ones(3, 10, 10) * 2) assert torch.equal(target, torch.tensor(3 * 3)) + + +@pytest.mark.parametrize("indexes_to_keep", [1, 0, [0, 1], (0, 1)]) +def test_indexed_dataset(indexes_to_keep): + X = [torch.arange(10), torch.arange(20)[::2]] + ds = TensorDataset(*X) + subset = IndexedSubsetDataset(ds, indexes_to_keep=indexes_to_keep) + if isinstance(indexes_to_keep, (list, tuple)): + curr_x = torch.vstack([X[ind] for ind in indexes_to_keep]).T + else: + curr_x = X[indexes_to_keep] + ds_elements = torch.tensor([subset[i] for i in range(len(subset))]) + assert torch.equal(curr_x, ds_elements) diff --git a/test/renate/memory/test_storage.py b/test/renate/memory/test_storage.py index e6fd64c3..1dc6844d 100644 --- a/test/renate/memory/test_storage.py +++ b/test/renate/memory/test_storage.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import torch +from torch.utils.data import Dataset -from renate.memory.storage import MemoryMappedTensorStorage +from renate.memory.storage import FileTensorStorage def nested_tensors_equal(t1, t2): @@ -11,7 +12,7 @@ def nested_tensors_equal(t1, t2): return False if isinstance(t1, torch.Tensor): return torch.equal(t1, t2) - if isinstance(t1, tuple): + if isinstance(t1, (tuple, list)): if len(t1) != len(t2): return False return all(nested_tensors_equal(t1_, t2_) for t1_, t2_ in zip(t1, t2)) @@ -21,25 +22,65 @@ def nested_tensors_equal(t1, t2): return all(nested_tensors_equal(t1[key], t2[key]) for key in t1.keys()) -@pytest.mark.parametrize("length", [0, 1, 10]) +def make_dataset_same_sizes(dataset_length, return_type): + class CustomDataset(Dataset): + def __init__(self, x, y, return_type="tensor") -> None: + self.x, self.y = x, y + self.return_type = return_type + + def __getitem__(self, index): + if self.return_type == "tensor": + return self.x[index] + elif self.return_type == "tuple": + return (self.x[index], self.y[index]) + elif self.return_type == "dict": + return {"x": self.x[index], "y": self.y[index]} + elif self.return_type == "list": + return [self.x[index], self.y[index]] + + def __len__(self): + return self.x.shape[0] + + data_x = torch.rand(dataset_length, 3, 32, 32) + data_y = torch.randint_like(data_x, 5) + + return CustomDataset(data_x, data_y, return_type=return_type) + + +def make_dataset_different_sizes(dataset_length, return_type): + class CustomDataset(Dataset): + def __init__(self, return_type="tensor") -> None: + self.return_type = return_type + + def __getitem__(self, index): + x = torch.ones(index + 2, index + 3).float() * (index + 1) + y = torch.ones(index + 2, index + 3).int() * (index + 2) + if self.return_type == "tensor": + return x + elif self.return_type == "tuple": + return (x, y) + elif self.return_type == "dict": + return {"x": x, "y": y} + elif self.return_type == "list": + return [x, y] + + def __len__(self): + return dataset_length + + return CustomDataset(return_type=return_type) 
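The tests below exercise FileTensorStorage end to end; a rough standalone usage sketch of the new storage together with IndexedSubsetDataset (the directory name is a placeholder) could look like this:

import os

import torch
from torch.utils.data import TensorDataset

from renate.data.datasets import IndexedSubsetDataset
from renate.memory.storage import FileTensorStorage

# Keep only the inputs (column 0) of a small dataset, as the buffer's save() does.
dataset = TensorDataset(torch.rand(8, 3), torch.randint(0, 2, (8,)))
inputs_only = IndexedSubsetDataset(dataset, indexes_to_keep=0)

# Write one `<index>.pt` file per data point, then read items back lazily.
os.makedirs("buffer_dir", exist_ok=True)
storage = FileTensorStorage("buffer_dir")
storage.dump_dataset(inputs_only)

restored = FileTensorStorage("buffer_dir")
assert torch.equal(restored[0], inputs_only[0])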
+ + +@pytest.mark.parametrize("length", [1, 10]) +@pytest.mark.parametrize("return_type", ["tuple", "dict", "tensor", "list"]) @pytest.mark.parametrize( - "data_point", - [ - torch.tensor(1), - torch.ones(3), - (torch.ones(3), torch.tensor(1)), - ({"a": torch.ones(2, 3, 3), "b": torch.zeros(2)}, torch.tensor(4)), - {"a": (torch.ones(2, 3, 3), torch.zeros(2)), "b": torch.tensor(2)}, - ], + "dataset_maker_fn", [make_dataset_same_sizes, make_dataset_different_sizes] ) -def test_storage(tmpdir, length, data_point): - """Tests the memory-mapped tensor storage for different nested tensor structures.""" - storage = MemoryMappedTensorStorage(tmpdir, data_point, length) - for i in range(length): - storage[i] = data_point +def test_memory_storage_different_sizes(tmpdir, length, return_type, dataset_maker_fn): + ds = dataset_maker_fn(length, return_type) + storage = FileTensorStorage(tmpdir) + storage.dump_dataset(ds) del storage - storage = MemoryMappedTensorStorage(tmpdir, data_point, length) - + storage = FileTensorStorage(tmpdir) for i in range(length): - assert nested_tensors_equal(storage[i], data_point) + assert nested_tensors_equal(storage[i], ds[i]) From 380a461d2dd2c2e32c74d024852d6c4f1ccff89d Mon Sep 17 00:00:00 2001 From: Prabhu Teja S Date: Wed, 7 Jun 2023 18:33:20 +0200 Subject: [PATCH 33/89] Adding Clear dataset to renate.benchmarks (#287) --- src/renate/benchmark/datasets/vision_datasets.py | 2 ++ src/renate/benchmark/experiment_config.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 55acf107..c2b3ed46 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -127,6 +127,8 @@ class TorchVisionDataModule(RenateDataModule): }, "FashionMNIST": {"mean": 0.2860405969887955, "std": 0.3530242445149223}, "MNIST": {"mean": 0.1306604762738429, "std": 0.30810780385646264}, + "CLEAR10": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, + "CLEAR100": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, } def __init__( diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 0a185d12..4cb3ca4d 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -306,6 +306,14 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]: _get_normalize_transform(dataset_name), ] ) + if dataset_name in ["CLEAR10", "CLEAR100"]: + return transforms.Compose( + [ + transforms.Resize(224), + transforms.RandomCrop(224), + _get_normalize_transform(dataset_name), + ] + ) raise ValueError(f"Unknown dataset `{dataset_name}`.") @@ -318,4 +326,12 @@ def test_transform(dataset_name: str) -> Optional[transforms.Normalize]: return None if dataset_name in ["CIFAR10", "CIFAR100"]: return _get_normalize_transform(dataset_name) + if dataset_name in ["CLEAR10", "CLEAR100"]: + return transforms.Compose( + [ + transforms.Resize(224), + transforms.CenterCrop(224), + _get_normalize_transform(dataset_name), + ] + ) raise ValueError(f"Unknown dataset `{dataset_name}`.") From ca6f8e700d317c37892bd486b1eac4278def4906 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 8 Jun 2023 15:38:26 +0200 Subject: [PATCH 34/89] Upload custom files and folders with a SageMaker training Job (#286) --- src/renate/benchmark/experimentation.py | 1 - src/renate/training/training.py | 12 ++++++------ 2 files changed, 6 insertions(+), 7 
deletions(-) diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 77bd0144..213a8527 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -392,7 +392,6 @@ def _execute_experiment_job_remotely(experiment_outputs_url: str, **job_kwargs: experiment_outputs_url ), f"experiment_outputs_url {experiment_outputs_url} is not on S3." return submit_remote_job( - source_dir=None, experiment_outputs_url=experiment_outputs_url, optional_dependencies="benchmark", **job_kwargs, diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 28011ae3..3e33917d 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -72,7 +72,7 @@ def run_training_job( input_state_url: Optional[str] = None, output_state_url: Optional[str] = None, working_directory: Optional[str] = defaults.WORKING_DIRECTORY, - source_dir: Optional[str] = None, + dependencies: Optional[List[str]] = None, config_file: Optional[str] = None, requirements_file: Optional[str] = None, role: Optional[str] = None, @@ -111,7 +111,8 @@ def run_training_job( input_state_url: Path to the Renate model state. output_state_url: Path where Renate model state will be stored. working_directory: Path to the working directory. - source_dir: (SageMaker backend only) Root directory which will be moved to SageMaker. + dependencies: (SageMaker backend only) List of strings containing absolute or relative paths + to files and directories that will be uploaded as part of the SageMaker training job. config_file: File containing the definition of `model_fn` and `data_module_fn`. requirements_file: (SageMaker backend only) Path to requirements.txt containing environment dependencies. 
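As a concrete illustration of the new dependencies argument, a hypothetical SageMaker launch that ships a local helper module and a config folder with the training job might look as follows (paths, role ARN, and instance type are placeholders; other arguments are kept to the ones visible in the signature above):

from renate.training import run_training_job

run_training_job(
    config_file="renate_config.py",
    updater="ER",
    max_epochs=5,
    backend="sagemaker",
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role ARN
    instance_type="ml.g4dn.xlarge",
    requirements_file="requirements.txt",
    dependencies=["my_utils.py", "configs/"],  # uploaded alongside the job
)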
@@ -188,7 +189,7 @@ def run_training_job( max_epochs=max_epochs, task_id=task_id, chunk_id=chunk_id, - source_dir=source_dir, + dependencies=dependencies or [], requirements_file=requirements_file, role=role, instance_type=instance_type, @@ -625,7 +626,7 @@ def _execute_training_and_tuning_job_locally( def submit_remote_job( - source_dir: Union[str, None], + dependencies: List[str], role: str, instance_type: str, instance_count: int, @@ -641,12 +642,11 @@ def submit_remote_job( job_timestamp = defaults.current_timestamp() job_name = f"{job_name}-{job_timestamp}" tmp_dir = tempfile.mkdtemp() - dependencies = _prepare_remote_job( + dependencies += _prepare_remote_job( tmp_dir=tmp_dir, optional_dependencies=optional_dependencies, **job_kwargs ) PyTorch( entry_point=tuning_script, - source_dir=None if source_dir is None else str(source_dir), instance_type=instance_type, instance_count=instance_count, py_version=defaults.PYTHON_VERSION, From f540367a29060585c2102f238a25fc710818e825 Mon Sep 17 00:00:00 2001 From: Wes Kendrick Date: Fri, 9 Jun 2023 13:28:40 +0200 Subject: [PATCH 35/89] Run sagemaker tests from GitHub Actions (#275) --- .github/workflows/sagemaker_tests.yml | 69 +++++++++++++++++++ .../generate_requirements.py | 31 +++++++++ test/integration_tests/run_experiment.py | 13 ++++ test/integration_tests/run_test.py | 11 +++ 4 files changed, 124 insertions(+) create mode 100644 .github/workflows/sagemaker_tests.yml create mode 100644 test/integration_tests/generate_requirements.py diff --git a/.github/workflows/sagemaker_tests.yml b/.github/workflows/sagemaker_tests.yml new file mode 100644 index 00000000..c576bcd6 --- /dev/null +++ b/.github/workflows/sagemaker_tests.yml @@ -0,0 +1,69 @@ +name: SageMaker Tests + +# All tests in this file: +# 1. Run when launched manually +# 2. 
Require AWS credentials and launch training jobs + +on: + workflow_dispatch: +# pull_request: # UNCOMMENT FOR TESTING +# branches: # UNCOMMENT FOR TESTING +# - main # UNCOMMENT FOR TESTING +# - dev # UNCOMMENT FOR TESTING + +env: + AWS_DEFAULT_REGION: us-west-2 + AWS_ROLE: ${{ secrets.PROD_AWS_END_TO_END_TEST_ROLE_ARN }} + +permissions: + id-token: write # This is required for requesting the JWT + contents: read # This is required for actions/checkout + +jobs: + launch-sagemaker-jobs: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: 3.9 + cache: 'pip' + - name: Install Renate + run: | + python -m pip install --upgrade pip + python -m pip install -e '.[dev]' + - name: Install toml library + run: pip install toml + - name: Write requirements.txt for SageMaker training jobs + run: | + python test/integration_tests/generate_requirements.py + - name: Get Credentials + uses: aws-actions/configure-aws-credentials@v2 + with: + role-to-assume: ${{ secrets.PROD_AWS_END_TO_END_TEST_ROLE_ARN }} + role-session-name: integtestsession + aws-region: ${{ env.AWS_DEFAULT_REGION }} + - name: Launch SageMaker Jobs + run: | + import os + import subprocess + + target_directory = 'test/integration_tests/configs/suites/main' + files = [f for f in os.listdir(target_directory) if os.path.isfile(os.path.join(target_directory, f))] + + for file in files: + process = subprocess.Popen( + [ + "python", + "test/integration_tests/run_test.py", + "--test-file", + file, + "--seed", + "0", + "--requirements-file", + "test/integration_tests/requirements.txt", + ] + ) + process.wait() + + shell: python diff --git a/test/integration_tests/generate_requirements.py b/test/integration_tests/generate_requirements.py new file mode 100644 index 00000000..a7416e4d --- /dev/null +++ b/test/integration_tests/generate_requirements.py @@ -0,0 +1,31 @@ +import toml +from typing import List +import shutil + + +def generate_requirements_file_for_sagemaker_training_jobs( + toml_file: str, keys: List[str], output_path: str +): + dependencies = [] + + # create a copy of the main requirements.txt file + shutil.copyfile("./requirements.txt", "test/integration_tests/requirements.txt") + + # Parse the .toml file + with open(toml_file, "r") as file: + pyproject_toml = toml.load(file) + + for key in keys: + value = pyproject_toml["project"]["optional-dependencies"][key] + dependencies += value + + # Write the values into the output file + with open(f"{output_path}/requirements.txt", "a") as file: + for dependency in dependencies: + file.write(f"{dependency}\n") + + +if __name__ == "__main__": + generate_requirements_file_for_sagemaker_training_jobs( + "./pyproject.toml", ["avalanche", "benchmark", "dev"], "test/integration_tests" + ) diff --git a/test/integration_tests/run_experiment.py b/test/integration_tests/run_experiment.py index c249171d..b6ae22c3 100644 --- a/test/integration_tests/run_experiment.py +++ b/test/integration_tests/run_experiment.py @@ -3,6 +3,7 @@ import argparse import json from pathlib import Path +import os import boto3 from syne_tune.backend.sagemaker_backend.sagemaker_utils import get_execution_role @@ -59,10 +60,21 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): default=12 * 3600, help="Maximum execution time.", ) + parser.add_argument( + f"--requirements-file", + type=str, + required=False, + help="Path to requirements file", + ) args = parser.parse_args() config_space = load_config( args.scenario_file, args.model_file, 
args.updater_file, args.dataset_file ) + current_folder = Path(os.path.dirname(__file__)) + requirements_file = args.requirements_file + if not requirements_file: + requirements_file = current_folder.parent.parent / "requirements.txt" + if args.backend == "local": experiment_outputs_url = ( Path("tmp") @@ -94,4 +106,5 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): job_name=args.job_name[:36], devices=1, strategy="ddp", + requirements_file=args.requirements_file, ) diff --git a/test/integration_tests/run_test.py b/test/integration_tests/run_test.py index 8bcc610f..577d69d2 100644 --- a/test/integration_tests/run_test.py +++ b/test/integration_tests/run_test.py @@ -20,6 +20,12 @@ required=True, help="Seed.", ) + parser.add_argument( + f"--requirements-file", + type=str, + required=False, + help="Path to requirements file", + ) args = parser.parse_args() test_suite = "main" current_folder = Path(os.path.dirname(__file__)) @@ -30,6 +36,9 @@ with open(test_file) as f: test_config = json.load(f) job_name = f"{test_config['job_name']}" + requirements_file = args.requirements_file + if not requirements_file: + requirements_file = current_folder.parent.parent / "requirements.txt" process = subprocess.Popen( [ "python", @@ -50,6 +59,8 @@ test_suite, "--seed", str(args.seed), + "--requirements-file", + requirements_file, ] ) process.wait() From 93c30e81152ae155af7cdbe7c68a41dc5d2e3e45 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 14 Jun 2023 11:26:28 +0200 Subject: [PATCH 36/89] Fix Security Problem with `transformers` (#298) --- requirements.txt | 2 +- test/integration_tests/generate_requirements.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f31e6ef5..c35059b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ torchmetrics~=0.10.3 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 datasets>=2.9.0, <2.12.1 -transformers>4.23.0, <4.29.3 +transformers>=4.30.0, <4.30.2 scipy>=1.9.0, <1.10.2 diff --git a/test/integration_tests/generate_requirements.py b/test/integration_tests/generate_requirements.py index a7416e4d..a6ac319f 100644 --- a/test/integration_tests/generate_requirements.py +++ b/test/integration_tests/generate_requirements.py @@ -1,3 +1,5 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 import toml from typing import List import shutil From 15541cf01dc251c34ed9f238de72a32e894626b8 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 14 Jun 2023 16:33:28 +0200 Subject: [PATCH 37/89] Custom Optimizer and LR schedulers (#290) --- .github/workflows/run_unit_tests.yml | 2 +- doc/examples/custom_training_script.rst | 10 +- doc/getting_started/how_to_renate_config.rst | 32 ++- doc/getting_started/how_to_run_training.rst | 5 + doc/getting_started/shift_detection.rst | 8 +- examples/custom_training_script/train.py | 16 +- examples/getting_started/renate_config.py | 14 +- .../shift_detection/image_shift_detection.py | 3 +- src/renate/cli/parsing_functions.py | 94 ++++----- src/renate/cli/run_training.py | 29 ++- src/renate/defaults.py | 8 +- src/renate/memory/storage.py | 2 +- src/renate/shift/detector.py | 6 +- src/renate/shift/ks_detector.py | 2 +- .../updaters/avalanche/model_updater.py | 119 +++++------ src/renate/updaters/experimental/er.py | 113 ++++------- .../updaters/experimental/fine_tuning.py | 26 ++- src/renate/updaters/experimental/gdumb.py | 27 ++- src/renate/updaters/experimental/joint.py | 27 ++- .../updaters/experimental/offline_er.py | 33 ++-- .../updaters/experimental/repeated_distill.py | 81 ++------ src/renate/updaters/learner.py | 60 +++--- src/renate/updaters/model_updater.py | 21 +- src/renate/utils/deepspeed.py | 2 +- src/renate/utils/distributed_strategies.py | 6 +- src/renate/utils/module.py | 27 ++- src/renate/utils/optimizer.py | 44 +---- test/conftest.py | 187 ++++-------------- test/renate/data/test_datasets.py | 2 +- .../config_custom_optimizer.py | 43 ++++ test/renate/shift/test_detectors.py | 2 +- test/renate/training/test_run_training.py | 31 +-- .../avalanche/test_avalanche_learner.py | 10 +- .../avalanche/test_avalanche_model_updater.py | 11 +- test/renate/updaters/experimental/test_er.py | 35 ++-- .../updaters/experimental/test_fine_tuning.py | 1 + .../updaters/experimental/test_joint.py | 9 +- .../experimental/test_repeated_distill.py | 30 ++- test/renate/updaters/test_learner.py | 11 +- test/renate/updaters/test_model_updater.py | 5 +- .../utils/test_distributed_strategies.py | 2 +- test/renate/utils/test_optimizer.py | 46 +---- 42 files changed, 545 insertions(+), 697 deletions(-) create mode 100644 test/renate/renate_config_files/config_custom_optimizer.py diff --git a/.github/workflows/run_unit_tests.yml b/.github/workflows/run_unit_tests.yml index 73c56972..21a52301 100644 --- a/.github/workflows/run_unit_tests.yml +++ b/.github/workflows/run_unit_tests.yml @@ -42,7 +42,7 @@ jobs: - name: Black uses: psf/black@stable with: - version: "~= 22.12.0" + options: "--check --verbose" - name: Run unit tests run: PYTHONPATH=test python -m pytest test/ diff --git a/doc/examples/custom_training_script.rst b/doc/examples/custom_training_script.rst index 5f756cd6..7c509921 100644 --- a/doc/examples/custom_training_script.rst +++ b/doc/examples/custom_training_script.rst @@ -1,4 +1,4 @@ -Using Renate in a Costum Training Script +Using Renate in a Custom Training Script **************************************** Usually, we use Renate by writing a :code:`renate_config.py` and launching training jobs via the @@ -13,7 +13,7 @@ functions should produce one loss value per input example (:code:`reduction="non built-in losses), as explained in :ref:`getting_started/how_to_renate_config:Loss Definition`. .. 
literalinclude:: ../../examples/custom_training_script/train.py - :lines: 12-16 + :lines: 14-18 Next, we prepare the dataset on which we want to fine-tune the model. Here, we use the :py:class:`~renate.benchmark.datasets.nlp_datasets.HuggingFaceTextDataModule` to load the @@ -21,7 +21,7 @@ Next, we prepare the dataset on which we want to fine-tune the model. Here, we u if we pass it the corresponding tokenizer. .. literalinclude:: ../../examples/custom_training_script/train.py - :lines: 19-24 + :lines: 21-26 Now we can instantiate a :py:class:`~renate.updaters.model_updater.ModelUpdater` to perform the training. Since we just want to fine-tune the model on a single dataset here, we use the @@ -34,11 +34,11 @@ Once the model updater is created, we initiate the training by calling its (optionally) validation datasets. .. literalinclude:: ../../examples/custom_training_script/train.py - :lines: 27-37 + :lines: 29-39 Once the training is terminated, your model is ready to deploy. Here, we just save its weights for later use using standard `PyTorch functionality `_. .. literalinclude:: ../../examples/custom_training_script/train.py - :lines: 41- + :lines: 43 diff --git a/doc/getting_started/how_to_renate_config.rst b/doc/getting_started/how_to_renate_config.rst index 6dbb692a..4a1e3aa6 100644 --- a/doc/getting_started/how_to_renate_config.rst +++ b/doc/getting_started/how_to_renate_config.rst @@ -29,7 +29,7 @@ method, which automatically handles model hyperparameters. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 14-40 + :lines: 18-42 If you are using a torch model with **no or fixed hyperparameters**, you can use :py:class:`~renate.models.renate_module.RenateWrapper`. @@ -63,7 +63,7 @@ An example of this for the task of MNIST classification above as .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Loss function example - :lines: 95-96 + :lines: 99-100 Please note, loss functions should not be reduced. @@ -85,7 +85,29 @@ such as data subsampling or splitting. .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 41-68 + :lines: 45-72 + +Optimizer +========= + +Optimizers such as ``SGD`` or ``Adam`` can be selected by passing the corresponding arguments. +If you want to use other optimizers, you can do so by returning a partial optimizer object as +outlined in the example below. + +.. literalinclude:: ../../examples/getting_started/renate_config.py + :caption: Example + :lines: 103-104 + +Learning Rate Schedulers +======================== + +A learning rate scheduler can be provided by creating a function as demonstrated below. +This function will need to return a partial object of a learning rate scheduler as well as a string +that indicates whether the scheduler is updated after each ``epoch`` or after each ``step``. + +.. literalinclude:: ../../examples/getting_started/renate_config.py + :caption: Example + :lines: 107-108 Transforms ========== @@ -130,7 +152,7 @@ These are optional as well but, if omitted, Renate will use :code:`train_transfo .. literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 71-78 + :lines: 75-82 Custom Metrics ============== @@ -142,7 +164,7 @@ or created ad-hoc by implementing the same interface .. 
literalinclude:: ../../examples/getting_started/renate_config.py :caption: Example - :lines: 93- + :lines: 95-96 To enable the usage of additional metrics in Renate it is sufficient to implement the :code:`metrics_fn` function, returning a dictionary where the key is a string containing the diff --git a/doc/getting_started/how_to_run_training.rst b/doc/getting_started/how_to_run_training.rst index 73bbf6c9..221d509b 100644 --- a/doc/getting_started/how_to_run_training.rst +++ b/doc/getting_started/how_to_run_training.rst @@ -47,6 +47,11 @@ instantiating the method you selected. See :doc:`supported_algorithms` for more +.. note:: + If you have defined the ``optimizer_fn`` function in your Renate config, do not pass values for the keys + ``optimizer``, ``momentum``, ``weight_decay``, or ``learning_rate``, unless you have specified them as + :ref:`custom arguments `. + Once the configuration of the learning algorithm is specified, we need to set another couple of arguments in the :py:func:`~renate.training.training.run_training_job` function to make sure we obtain the desired behavior: diff --git a/doc/getting_started/shift_detection.rst b/doc/getting_started/shift_detection.rst index 03264483..543c7bf3 100644 --- a/doc/getting_started/shift_detection.rst +++ b/doc/getting_started/shift_detection.rst @@ -56,7 +56,7 @@ In practice, you would ingest your own data here, see the documentation for :py:class:`~renate.data.data_module.RenateDataModule`. .. literalinclude:: ../../examples/shift_detection/image_shift_detection.py - :lines: 13-16 + :lines: 12-15 For the purpose of this demonstration, we now generate a reference dataset as well as two query datasets: one from the same distribution, and one where we simulate a distribution shift by @@ -67,14 +67,14 @@ The query dataset would be the data you want to check for distribution shift, e. during the deployment of your model. .. literalinclude:: ../../examples/shift_detection/image_shift_detection.py - :lines: 22-26 + :lines: 21-25 Shift detection methods rely on informative (and relatively low-dimensional) features. Here, we use a pretrained ResNet model and chop of its output layer. This leads to 512-dimensional vectorial features. .. literalinclude:: ../../examples/shift_detection/image_shift_detection.py - :lines: 31-33 + :lines: 30-32 You can use any :py:class:`torch.nn.Module`, which may be a pretrained model or use a custom model that has been trained on the data at hand. @@ -85,7 +85,7 @@ Now we can instantiate an MMD-based shift detector. We first fit it to our refer then score both the in-distribution query dataset as well as the out-of-distribution query dataset. .. literalinclude:: ../../examples/shift_detection/image_shift_detection.py - :lines: 39-47 + :lines: 38-46 In this toy example, the shift is quite obvious and we will see a very high score for the out-of-distribution data:: diff --git a/examples/custom_training_script/train.py b/examples/custom_training_script/train.py index 026b67e5..a5831473 100644 --- a/examples/custom_training_script/train.py +++ b/examples/custom_training_script/train.py @@ -1,21 +1,23 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +from functools import partial + import torch import transformers +from torch.optim import Adam from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.models.renate_module import RenateWrapper from renate.updaters.experimental.fine_tuning import FineTuningModelUpdater - -### Create model. +# Create model. transformer_model = transformers.DistilBertForSequenceClassification.from_pretrained( "distilbert-base-uncased", num_labels=2, return_dict=False ) model = RenateWrapper(transformer_model) loss_fn = torch.nn.CrossEntropyLoss(reduction="none") -### Prepare data. +# Prepare data. tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased") data_module = HuggingFaceTextDataModule( "data", dataset_name="rotten_tomatoes", tokenizer=tokenizer, val_size=0.2 @@ -23,12 +25,12 @@ data_module.prepare_data() # For multi-GPU, call only on rank 0. data_module.setup() -### Instantiate renate ModelUpdater and run fine-tuning. +# Instantiate renate ModelUpdater and run fine-tuning. +optimizer = partial(Adam, learning_rate=3e-4) updater = FineTuningModelUpdater( model, loss_fn, - optimizer="Adam", - learning_rate=3e-4, + optimizer=optimizer, batch_size=32, max_epochs=3, input_state_folder=None, @@ -37,5 +39,5 @@ updater.update(data_module.train_data(), data_module.val_data()) -### Do something with model, e.g., save its weights for later use. +# Do something with model, e.g., save its weights for later use. torch.save(model.state_dict(), "model_weights.pt") diff --git a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index 756fc729..5045cb9d 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -1,9 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Optional +from functools import partial +from typing import Callable, Dict, Generator, Optional, Tuple import torch import torchvision +from torch.nn import Parameter +from torch.optim import AdamW, Optimizer +from torch.optim.lr_scheduler import StepLR, _LRScheduler from torchmetrics import Accuracy from torchvision.transforms import transforms @@ -94,3 +98,11 @@ def metrics_fn() -> Dict: def loss_fn() -> torch.nn.Module: return torch.nn.CrossEntropyLoss(reduction="none") + + +def optimizer_fn() -> Callable[[Generator[Parameter]], Optimizer]: + return partial(AdamW, lr=0.01, weight_decay=0.0) + + +def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]: + return partial(StepLR, step_size=10, gamma=0.1), "epoch" diff --git a/examples/shift_detection/image_shift_detection.py b/examples/shift_detection/image_shift_detection.py index f0b069fb..7a27c8ee 100644 --- a/examples/shift_detection/image_shift_detection.py +++ b/examples/shift_detection/image_shift_detection.py @@ -1,14 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import torch -from torchvision.models import resnet18, ResNet18_Weights +from torchvision.models import ResNet18_Weights, resnet18 from torchvision.transforms import GaussianBlur from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule from renate.data.datasets import _TransformedDataset from renate.shift.mmd_detectors import MMDCovariateShiftDetector - # Load CIFAR-10 training dataset. 
data_module = TorchVisionDataModule(data_path="data", dataset_name="CIFAR10", val_size=0.2) data_module.prepare_data() diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 381168f5..a3b55fee 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -4,11 +4,11 @@ import ast import inspect import sys -import pytorch_lightning as pl from importlib.util import find_spec from types import ModuleType from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +import pytorch_lightning as pl from syne_tune.optimizer.scheduler import TrialScheduler from renate import defaults @@ -48,17 +48,7 @@ def get_updater_and_learner_kwargs( """Returns the model updater class and the keyword arguments for the learner.""" if args.updater.startswith("Avalanche-") and find_spec("avalanche", None) is None: raise ImportError("Avalanche is not installed. Please run `pip install Renate[avalanche]`.") - learner_args = [ - "optimizer", - "learning_rate", - "learning_rate_scheduler", - "learning_rate_scheduler_step_size", - "learning_rate_scheduler_gamma", - "momentum", - "weight_decay", - "batch_size", - "seed", - ] + learner_args = ["batch_size", "seed"] base_er_args = learner_args + [ "loss_weight", "ema_memory_update_gamma", @@ -182,7 +172,7 @@ def parse_arguments( to all functions specified in ``function_names``. """ arguments = _standard_arguments() - _add_hyperparameter_arguments(arguments) + _add_hyperparameter_arguments(arguments, "optimizer_fn" not in vars(config_module)) function_args = {} for function_name in function_names: function_args[function_name] = get_function_args( @@ -336,7 +326,9 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: } -def _add_hyperparameter_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: +def _add_hyperparameter_arguments( + arguments: Dict[str, Dict[str, Any]], add_optimizer_args: bool +) -> None: """Adds arguments for the specified updater.""" updater: Optional[str] = None for i, arg in enumerate(sys.argv): @@ -348,56 +340,46 @@ def _add_hyperparameter_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: assert updater in parse_by_updater, f"Unknown updater {updater}." parse_by_updater[updater](arguments) - _add_optimizer_arguments(arguments) + _add_optimizer_arguments(arguments, add_optimizer_args) for value in arguments.values(): if "argument_group" not in value: value["argument_group"] = HYPERPARAMETER_ARGS_GROUP -def _add_optimizer_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: +def _add_optimizer_arguments( + arguments: Dict[str, Dict[str, Any]], add_optimizer_args: bool +) -> None: """A helper function that adds optimizer arguments.""" + if add_optimizer_args: + arguments.update( + { + "optimizer": { + "type": str, + "default": defaults.OPTIMIZER, + "help": "Optimizer used for training. Options: SGD or Adam. Default: " + f"{defaults.OPTIMIZER}.", + }, + "learning_rate": { + "type": float, + "default": defaults.LEARNING_RATE, + "help": "Learning rate used during model update. Default: " + f"{defaults.LEARNING_RATE}.", + }, + "momentum": { + "type": float, + "default": defaults.MOMENTUM, + "help": f"Momentum used during model update. Default: {defaults.MOMENTUM}.", + }, + "weight_decay": { + "type": float, + "default": defaults.WEIGHT_DECAY, + "help": "Weight decay used during model update. 
Default: " + f"{defaults.WEIGHT_DECAY}.", + }, + } + ) arguments.update( { - "optimizer": { - "type": str, - "default": defaults.OPTIMIZER, - "help": "Optimizer used for training. Options: SGD or Adam. Default: " - f"{defaults.OPTIMIZER}.", - }, - "learning_rate": { - "type": float, - "default": defaults.LEARNING_RATE, - "help": "Learning rate used during model update. Default: " - f"{defaults.LEARNING_RATE}.", - }, - "learning_rate_scheduler": { - "type": str, - "default": defaults.LEARNING_RATE_SCHEDULER, - "help": "Learning rate scheduler used during model update. Default: " - f"{defaults.LEARNING_RATE_SCHEDULER}.", - }, - "learning_rate_scheduler_step_size": { - "type": int, - "default": defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - "help": "Step size for learning rate scheduler. Default: " - f"{defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE}.", - }, - "learning_rate_scheduler_gamma": { - "type": float, - "default": defaults.LEARNING_RATE_SCHEDULER_GAMMA, - "help": "Gamma for learning rate scheduler. Default: " - f"{defaults.LEARNING_RATE_SCHEDULER_GAMMA}.", - }, - "momentum": { - "type": float, - "default": defaults.MOMENTUM, - "help": f"Momentum used during model update. Default: {defaults.MOMENTUM}.", - }, - "weight_decay": { - "type": float, - "default": defaults.WEIGHT_DECAY, - "help": f"Weight decay used during model update. Default: {defaults.WEIGHT_DECAY}.", - }, "batch_size": { "type": int, "default": defaults.BATCH_SIZE, diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index 56e38756..c1281beb 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -18,11 +18,14 @@ from renate.utils.file import maybe_download_from_s3, move_to_uri from renate.utils.module import ( get_and_setup_data_module, + get_learning_rate_scheduler, get_loss_fn, get_metrics, get_model, + get_optimizer, import_module, ) +from renate.utils.optimizer import create_partial_optimizer from renate.utils.syne_tune import redirect_to_tmp logger = logging.getLogger(__name__) @@ -99,9 +102,11 @@ def run(self): "train_transform", "test_transform", "buffer_transform", - "metrics_fn", "scheduler_fn", "loss_fn", + "optimizer_fn", + "lr_scheduler_fn", + "metrics_fn", ], ignore_args=["data_path", "model_state_url"], ) @@ -125,13 +130,32 @@ def run(self): not args.updater.startswith("Avalanche-"), **get_function_kwargs(args=args, function_args=function_args["loss_fn"]), ) - + partial_optimizer = get_optimizer( + config_module, + **get_function_kwargs(args=args, function_args=function_args["optimizer_fn"]), + ) + if partial_optimizer is None: + partial_optimizer = create_partial_optimizer( + optimizer=args.optimizer, + lr=args.learning_rate, + momentum=args.momentum, + weight_decay=args.weight_decay, + ) + lr_scheduler_config = get_learning_rate_scheduler( + config_module, + **get_function_kwargs(args=args, function_args=function_args["lr_scheduler_fn"]), + ) + lr_scheduler_kwargs = {} + if lr_scheduler_config is not None: + lr_scheduler_kwargs["learning_rate_scheduler"] = lr_scheduler_config[0] + lr_scheduler_kwargs["learning_rate_scheduler_interval"] = lr_scheduler_config[1] metrics = get_metrics(config_module) model_updater_class, learner_kwargs = get_updater_and_learner_kwargs(args) model_updater = model_updater_class( model=model, + optimizer=partial_optimizer, input_state_folder=self._input_state_folder, output_state_folder=self._output_state_folder, max_epochs=args.max_epochs, @@ -146,6 +170,7 @@ def run(self): deterministic_trainer=args.deterministic_trainer, 
loss_fn=loss_fn, **learner_kwargs, + **lr_scheduler_kwargs, **get_transforms_dict(config_module, args, function_args), ) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 14b3c0fb..c2852964 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -10,12 +10,10 @@ OPTIMIZER = "Adam" SUPPORTED_OPTIMIZERS = ["Adam", "SGD"] SUPPORTED_OPTIMIZERS_TYPE = Literal["Adam", "SGD"] +LR_SCHEDULER_INTERVAL = "epoch" +SUPPORTED_LR_SCHEDULER_INTERVAL = ["epoch", "step"] +SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = Literal["epoch", "step"] LEARNING_RATE = 3e-4 -LEARNING_RATE_SCHEDULER = "ConstantLR" -LEARNING_RATE_SCHEDULER_GAMMA = 1.0 -LEARNING_RATE_SCHEDULER_STEP_SIZE = 1 -SUPPORTED_LEARNING_RATE_SCHEDULERS = ["ConstantLR", "ExponentialLR", "StepLR"] -SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = Literal["ConstantLR", "ExponentialLR", "StepLR"] MOMENTUM = 0.0 WEIGHT_DECAY = 0.0 MAX_EPOCHS = 50 diff --git a/src/renate/memory/storage.py b/src/renate/memory/storage.py index 42f15819..3c2ebc96 100644 --- a/src/renate/memory/storage.py +++ b/src/renate/memory/storage.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import math import os -from typing import Any, Tuple, Union, Optional from pathlib import Path +from typing import Any, Optional, Tuple, Union from warnings import warn import torch diff --git a/src/renate/shift/detector.py b/src/renate/shift/detector.py index c40c3839..27cb08ee 100644 --- a/src/renate/shift/detector.py +++ b/src/renate/shift/detector.py @@ -1,10 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import torch -from torch.utils.data import Dataset, DataLoader - from typing import Optional +import torch +from torch.utils.data import DataLoader, Dataset + from renate.utils.pytorch import move_tensors_to_device diff --git a/src/renate/shift/ks_detector.py b/src/renate/shift/ks_detector.py index 484d94dd..b0b2b9a6 100644 --- a/src/renate/shift/ks_detector.py +++ b/src/renate/shift/ks_detector.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from scipy.stats import kstest import torch +from scipy.stats import kstest from renate.shift.detector import ShiftDetectorWithFeatureExtractor diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index 7a4f04f3..bd92bef0 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging from pathlib import Path -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type import torch import torchmetrics @@ -11,7 +11,9 @@ from avalanche.training.supervised.icarl import _ICaRLPlugin from avalanche.training.templates import BaseSGDTemplate from syne_tune import Reporter +from torch.nn import Parameter from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import Dataset from renate import defaults @@ -61,10 +63,21 @@ def _load_learner( logged_metrics=self._logged_metrics, **self._transforms_kwargs, ) - optimizer, scheduler = self._dummy_learner.configure_optimizers() - optimizer, scheduler = optimizer[0], scheduler[0] - lr_scheduler_plugin = LRSchedulerPlugin(scheduler=scheduler) - plugins = [lr_scheduler_plugin] + plugins = [] + lr_scheduler_plugin = None + optimizers_scheduler = self._dummy_learner.configure_optimizers() + if isinstance(optimizers_scheduler, tuple): + optimizer, scheduler_config = optimizers_scheduler + optimizer, scheduler_config = optimizer[0], scheduler_config[0] + lr_scheduler_plugin = LRSchedulerPlugin( + scheduler=scheduler_config["scheduler"], + step_granularity="iteration" + if scheduler_config["interval"] == "step" + else scheduler_config["interval"], + ) + plugins.append(lr_scheduler_plugin) + else: + optimizer = optimizers_scheduler avalanche_learner = self._load_if_exists(self._input_state_folder) checkpoint_plugin = None @@ -107,7 +120,7 @@ def _get_device(self) -> torch.device: def _create_avalanche_learner( self, checkpoint_plugin: RenateCheckpointPlugin, - lr_scheduler_plugin: LRSchedulerPlugin, + lr_scheduler_plugin: Optional[LRSchedulerPlugin], optimizer: Optimizer, ) -> BaseSGDTemplate: """Returns an Avalanche learner based on the arguments passed to the ModelUpdater. @@ -117,7 +130,9 @@ def _create_avalanche_learner( lr_scheduler_plugin: Plugin to adapt the learning rate. optimizer: PyTorch optimizer object used for training the Avalanche learner. 
""" - plugins = [lr_scheduler_plugin] + plugins = [] + if lr_scheduler_plugin is not None: + plugins.append(lr_scheduler_plugin) if checkpoint_plugin is not None: plugins.append(checkpoint_plugin) avalanche_learner = self._dummy_learner.create_avalanche_learner( @@ -224,15 +239,11 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -255,26 +266,22 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { + "batch_size": batch_size, "memory_size": memory_size, "memory_batch_size": memory_batch_size, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, - "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=AvalancheReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -297,14 +304,10 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], ewc_lambda: float = defaults.EWC_LAMBDA, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -327,25 +330,21 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { - "optimizer": optimizer, - "learning_rate": 
learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "ewc_lambda": ewc_lambda, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=AvalancheEWCLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -368,15 +367,11 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], alpha: float = defaults.LWF_ALPHA, temperature: float = defaults.LWF_TEMPERATURE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -399,26 +394,22 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "alpha": alpha, "temperature": temperature, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=AvalancheLwFLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -441,15 +432,11 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - 
weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -474,24 +461,20 @@ def __init__( learner_kwargs = { "memory_size": memory_size, "memory_batch_size": memory_batch_size, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=AvalancheICaRLLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index 7dd279a0..3005a18a 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import abc +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -8,6 +9,8 @@ import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset, Subset from renate import defaults @@ -502,19 +505,15 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, alpha: float = defaults.ER_ALPHA, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -544,24 +543,20 @@ def __init__( "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, "alpha": alpha, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": 
learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=ExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -586,6 +581,7 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -593,13 +589,8 @@ def __init__( loss_normalization: int = defaults.LOSS_NORMALIZATION, alpha: float = defaults.DER_ALPHA, beta: float = defaults.DER_BETA, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -630,24 +621,20 @@ def __init__( "loss_normalization": loss_normalization, "alpha": alpha, "beta": beta, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=DarkExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -672,6 +659,7 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -680,13 +668,8 @@ def __init__( alpha: float = defaults.POD_ALPHA, distillation_type: str = defaults.POD_DISTILLATION_TYPE, normalize: bool = defaults.POD_NORMALIZE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = 
defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -718,24 +701,20 @@ def __init__( "alpha": alpha, "distillation_type": distillation_type, "normalize": normalize, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=PooledOutputDistillationExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -760,6 +739,7 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -771,13 +751,8 @@ def __init__( plastic_model_update_weight: float = defaults.CLS_PLASTIC_MODEL_UPDATE_WEIGHT, stable_model_update_probability: float = defaults.CLS_STABLE_MODEL_UPDATE_PROBABILITY, plastic_model_update_probability: float = defaults.CLS_PLASTIC_MODEL_UPDATE_PROBABILITY, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -812,24 +787,20 @@ def __init__( "plastic_model_update_weight": plastic_model_update_weight, "stable_model_update_probability": stable_model_update_probability, "plastic_model_update_probability": plastic_model_update_probability, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } 
super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=CLSExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -854,6 +825,7 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight: float = defaults.LOSS_WEIGHT, @@ -871,13 +843,8 @@ def __init__( pod_alpha: float = defaults.SER_POD_ALPHA, pod_distillation_type: str = defaults.SER_POD_DISTILLATION_TYPE, pod_normalize: bool = defaults.SER_POD_NORMALIZE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -918,24 +885,20 @@ def __init__( "pod_alpha": pod_alpha, "pod_distillation_type": pod_distillation_type, "pod_normalize": pod_normalize, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=SuperExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/fine_tuning.py b/src/renate/updaters/experimental/fine_tuning.py index 1a0db866..d9139269 100644 --- a/src/renate/updaters/experimental/fine_tuning.py +++ b/src/renate/updaters/experimental/fine_tuning.py @@ -1,10 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Dict, Optional +from functools import partial +from typing import Callable, Dict, List, Optional import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger +from torch.nn import Parameter +from torch.optim import Optimizer from renate import defaults from renate.models import RenateModule @@ -17,13 +20,9 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + optimizer: Callable[[List[Parameter]], Optimizer], + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -45,24 +44,21 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=Learner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index 3d1bf239..e978d023 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -1,11 +1,14 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional, Tuple +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from renate import defaults @@ -91,15 +94,11 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -125,24 +124,20 @@ def __init__( learner_kwargs = { "memory_size": memory_size, "memory_batch_size": memory_batch_size, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=GDumbLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index 0bcb9677..179b9667 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -1,12 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 import os -from typing import Any, Callable, Dict, Optional, Tuple +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from renate import defaults @@ -84,13 +87,9 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + optimizer: Callable[[List[Parameter]], Optimizer], + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -112,24 +111,20 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): learner_kwargs = { - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=JointLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 2c505b01..4836015d 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -1,12 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Callable, Dict, Optional, Tuple +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from renate import defaults @@ -39,11 +42,7 @@ class OfflineExperienceReplayLearner(ReplayLearner): samples. 
""" - def __init__( - self, - loss_weight_new_data: Optional[float] = None, - **kwargs, - ) -> None: + def __init__(self, loss_weight_new_data: Optional[float] = None, **kwargs) -> None: super().__init__(**kwargs) if loss_weight_new_data is not None and not (0.0 <= loss_weight_new_data <= 1.0): raise ValueError( @@ -138,16 +137,12 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, memory_batch_size: int = defaults.BATCH_SIZE, loss_weight_new_data: Optional[float] = None, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, @@ -174,24 +169,20 @@ def __init__( "memory_size": memory_size, "memory_batch_size": memory_batch_size, "loss_weight_new_data": loss_weight_new_data, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, "batch_size": batch_size, "seed": seed, - "loss_fn": loss_fn, } super().__init__( model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=OfflineExperienceReplayLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, diff --git a/src/renate/updaters/experimental/repeated_distill.py b/src/renate/updaters/experimental/repeated_distill.py index 0ee27b88..edfba727 100644 --- a/src/renate/updaters/experimental/repeated_distill.py +++ b/src/renate/updaters/experimental/repeated_distill.py @@ -1,12 +1,15 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import copy -from typing import Callable, Dict, Optional, Tuple +from functools import partial +from typing import Callable, Dict, List, Optional, Tuple import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer from torch.utils.data import DataLoader, Dataset from renate import defaults @@ -46,7 +49,7 @@ def extract_logits( model: torch.nn.Module, dataset: Dataset, batch_size: int, - task_id: int = defaults.TASK_ID, + task_id: Optional[str] = defaults.TASK_ID, ) -> torch.Tensor: """Extracts logits from a model for each point in a dataset. 
@@ -95,14 +98,10 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - optimizer: str = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: float = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, train_transform: Optional[Callable] = None, train_target_transform: Optional[Callable] = None, @@ -125,26 +124,18 @@ def __init__( early_stopping_enabled=False, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): - learner_kwargs = { - "memory_size": memory_size, - "optimizer": optimizer, - "learning_rate": learning_rate, - "learning_rate_scheduler": learning_rate_scheduler, - "learning_rate_scheduler_gamma": learning_rate_scheduler_gamma, - "learning_rate_scheduler_step_size": learning_rate_scheduler_step_size, - "momentum": momentum, - "weight_decay": weight_decay, - "batch_size": batch_size, - "seed": seed, - "loss_fn": loss_fn, - } + learner_kwargs = {"memory_size": memory_size, "batch_size": batch_size, "seed": seed} super().__init__( model=model, + loss_fn=loss_fn, + optimizer=optimizer, learner_class=RepeatedDistillationLearner, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, output_state_folder=output_state_folder, max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, train_transform=train_transform, train_target_transform=train_target_transform, test_transform=test_transform, @@ -211,50 +202,8 @@ def update( class RepeatedDistillationLearner(ReplayLearner): """A learner performing distillation.""" - def __init__( - self, - model: RenateModule, - loss_fn: torch.nn.Module, - memory_size: int, - optimizer: str = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: float = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, - batch_size: int = defaults.BATCH_SIZE, - train_transform: Optional[Callable] = None, - train_target_transform: Optional[Callable] = None, - test_transform: Optional[Callable] = None, - test_target_transform: Optional[Callable] = None, - buffer_transform: Optional[Callable] = None, - buffer_target_transform: Optional[Callable] = None, - logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, - seed: Optional[int] = None, - ) -> None: - super().__init__( - model=model, - memory_size=memory_size, - memory_batch_size=batch_size, - buffer_transform=buffer_transform, - buffer_target_transform=buffer_target_transform, - optimizer=optimizer, - learning_rate=learning_rate, - learning_rate_scheduler=learning_rate_scheduler, - 
learning_rate_scheduler_gamma=learning_rate_scheduler_gamma, - learning_rate_scheduler_step_size=learning_rate_scheduler_step_size, - momentum=momentum, - weight_decay=weight_decay, - batch_size=batch_size, - train_transform=train_transform, - train_target_transform=train_target_transform, - test_transform=test_transform, - test_target_transform=test_target_transform, - logged_metrics=logged_metrics, - seed=seed, - loss_fn=loss_fn, - ) + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) self._expert_logits: Optional[torch.Tensor] = None def update_expert_logits(self, new_expert_logits: torch.Tensor) -> None: diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index 3a5211fb..7c943f3e 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -10,6 +10,9 @@ from pytorch_lightning import LightningModule from pytorch_lightning.utilities.types import STEP_OUTPUT from torch import Tensor +from torch.nn import Parameter +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader, Dataset from renate import defaults @@ -18,7 +21,6 @@ from renate.memory import DataBuffer, InfiniteBuffer, ReservoirBuffer from renate.models import RenateModule from renate.types import NestedTensors -from renate.utils.optimizer import create_optimizer, create_scheduler from renate.utils.pytorch import get_generator @@ -38,13 +40,11 @@ class Learner(LightningModule, abc.ABC): Args: model: The model to be trained. - optimizer: Optimizer used for training. Options: `Adam` or `SGD`. - learning_rate: Initial learning rate used for training. - learning_rate_scheduler: Learning rate scheduler used for training. - learning_rate_scheduler_gamma: Learning rate scheduler gamma. - learning_rate_scheduler_step_size: Learning rate scheduler step size. - momentum: Momentum term (only relevant for optimizer `SGD`). - weight_decay: L2 regularization applied to all model weights. + optimizer: Partial optimizer used to create an optimizer by passing the model parameters. + learning_rate_scheduler: Partial object of learning rate scheduler that will be created by + passing the optimizer. + learning_rate_scheduler_interval: When to update the learning rate scheduler. + Options: `epoch` and `step`. batch_size: Training batch size. train_transform: The transformation applied during training. train_target_transform: The target transformation applied during testing. 
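The optimizer and learning rate scheduler arguments documented above are plain callables. A minimal sketch of how a caller can construct them, mirroring `create_partial_optimizer` and the `config_custom_optimizer.py` test config introduced elsewhere in this patch (the hyperparameter values below are placeholders):

from functools import partial

from torch.optim.lr_scheduler import StepLR

from renate.utils.optimizer import create_partial_optimizer

# Optimizer as a callable that will later be applied to the model parameters.
optimizer = create_partial_optimizer(optimizer="SGD", lr=0.01, momentum=0.0, weight_decay=0.0)
# Any callable mapping a parameter list to a torch.optim.Optimizer works as well,
# e.g. partial(torch.optim.SGD, lr=0.01).

# Scheduler as a callable that will later be applied to the optimizer, plus its update interval.
learning_rate_scheduler = partial(StepLR, step_size=10, gamma=0.1)
learning_rate_scheduler_interval = "epoch"  # or "step"

In a Renate config file, the same objects are returned from `optimizer_fn` and `lr_scheduler_fn` (the scheduler together with its interval), which `get_optimizer` and `get_learning_rate_scheduler` pick up.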
@@ -58,13 +58,9 @@ def __init__( self, model: RenateModule, loss_fn: torch.nn.Module, - optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, - learning_rate: float = defaults.LEARNING_RATE, - learning_rate_scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, # noqa: E501 - learning_rate_scheduler_gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, - learning_rate_scheduler_step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - momentum: float = defaults.MOMENTUM, - weight_decay: float = defaults.WEIGHT_DECAY, + optimizer: Callable[[List[Parameter]], Optimizer], + learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, train_transform: Optional[Callable] = None, train_target_transform: Optional[Callable] = None, @@ -77,12 +73,8 @@ def __init__( self._model = model self._loss_fn = loss_fn self._optimizer = optimizer - self._learning_rate = learning_rate self._learning_rate_scheduler = learning_rate_scheduler - self._learning_rate_scheduler_gamma = learning_rate_scheduler_gamma - self._learning_rate_scheduler_step_size = learning_rate_scheduler_step_size - self._momentum = momentum - self._weight_decay = weight_decay + self._learning_rate_scheduler_interval = learning_rate_scheduler_interval self._batch_size = batch_size self._train_transform = train_transform self._train_target_transform = train_target_transform @@ -98,6 +90,8 @@ def __init__( ignore=[ "model", "loss_fn", + "optimizer", + "learning_rate_scheduler", "components", "train_transform", "test_transform", @@ -274,22 +268,16 @@ def validation_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> def configure_optimizers( self, - ) -> Tuple[List[torch.optim.Optimizer], List[torch.optim.lr_scheduler._LRScheduler]]: - """PyTorch Lightning function to create an optimizer.""" - optimizer = create_optimizer( - params=self._model.get_params(self._task_id), - optimizer=self._optimizer, - lr=self._learning_rate, - momentum=self._momentum, - weight_decay=self._weight_decay, - ) - scheduler = create_scheduler( - scheduler=self._learning_rate_scheduler, - optimizer=optimizer, - gamma=self._learning_rate_scheduler_gamma, - step_size=self._learning_rate_scheduler_step_size, - ) - return [optimizer], [scheduler] + ) -> Union[Optimizer, Tuple[List[Optimizer], List[Dict[str, Any]]]]: + """PyTorch Lightning function to create optimizers and learning rate schedulers.""" + optimizer = self._optimizer(self._model.get_params(self._task_id)) + if self._learning_rate_scheduler is None: + return optimizer + lr_scheduler_config = { + "scheduler": self._learning_rate_scheduler(optimizer), + "interval": self._learning_rate_scheduler_interval, + } + return [optimizer], [lr_scheduler_config] def _update_metrics( self, diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index 821601e6..b16df0f0 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -13,6 +13,9 @@ from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.rank_zero import rank_zero_only from syne_tune import Reporter +from torch.nn import Parameter +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import Dataset from renate import defaults @@ -20,9 +23,8 @@ from 
renate.utils.distributed_strategies import create_strategy from renate.utils.file import unlink_file_or_folder from renate.utils.misc import int_or_str - -from ..models import RenateModule from .learner import Learner, ReplayLearner +from ..models import RenateModule logging_logger = logging.getLogger(__name__) @@ -155,7 +157,7 @@ def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> 1. If deepspeed is being used: The learner_state_path (which the checkpointing func) uses is a directory and not a file. - This directory has sharded state_dicts (of model and optimizers, depending on which + This directory has sharded state_dicts (of model and optimizers), depending on which deepspeed stage is used. There are three steps here a. combine all the shards into one big state dict. @@ -184,7 +186,7 @@ def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> torch.save(combined_state_dict, learner_state_path) self.teardown(trainer, pl_module, stage) elif learner_state_path.exists() and learner_state_path.is_file(): - ## This a normal file. We strip the model of any wrappers and save that. + # This is a normal file. We strip the model of any wrappers and save that. state_dict = torch.load(learner_state_path)["state_dict"] out_sd = {k.replace("_model.", "", 1): v for k, v in state_dict.items()} # Replace only 1 instance because we have to load it into RenateModule. @@ -239,11 +241,15 @@ class ModelUpdater(abc.ABC): def __init__( self, model: RenateModule, + loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], learner_class: Type[Learner], learner_kwargs: Optional[Dict[str, Any]] = None, input_state_folder: Optional[str] = None, output_state_folder: Optional[str] = None, max_epochs: int = defaults.MAX_EPOCHS, + learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 train_transform: Optional[Callable] = None, train_target_transform: Optional[Callable] = None, test_transform: Optional[Callable] = None, @@ -262,6 +268,13 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, ): self._learner_kwargs = learner_kwargs or {} + self._learner_kwargs["loss_fn"] = loss_fn + self._learner_kwargs["optimizer"] = optimizer + if learning_rate_scheduler is not None: + self._learner_kwargs["learning_rate_scheduler"] = learning_rate_scheduler + self._learner_kwargs[ + "learning_rate_scheduler_interval" + ] = learning_rate_scheduler_interval self._model = model self._learner_state_file: Optional[str] = None if input_state_folder is not None: diff --git a/src/renate/utils/deepspeed.py b/src/renate/utils/deepspeed.py index ed762d09..342e1cc2 100644 --- a/src/renate/utils/deepspeed.py +++ b/src/renate/utils/deepspeed.py @@ -88,7 +88,7 @@ def convert_zero_checkpoint_to_fp32_state_dict( extra_key = search_key(client_state["module"], "extra_state") extra_state = client_state["module"][extra_key] state_dict[extra_key] = extra_state - ## End of modifications + # End of modifications client_state = { key: value for key, value in client_state.items() if key not in deepspeed_states } diff --git a/src/renate/utils/distributed_strategies.py b/src/renate/utils/distributed_strategies.py index 28fe5082..28bf9de4 100644 --- a/src/renate/utils/distributed_strategies.py +++ b/src/renate/utils/distributed_strategies.py @@ -2,8 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 import warnings
from typing import Optional -from pytorch_lightning.strategies import Strategy, StrategyRegistry +from pytorch_lightning.strategies import Strategy, StrategyRegistry _SUPPORTED_STRATEGIES = [ "ddp_find_unused_parameters_false", @@ -41,13 +41,13 @@ def create_strategy(devices: int = 1, strategy_name: Optional["str"] = None) -> return None elif strategy_name in ["none", "None", None]: - ## Nothing is specified and devices > 1. Fall back to DDP + # Nothing is specified and devices > 1. Fall back to DDP return StrategyRegistry.get("ddp") elif "deepspeed" in strategy_name: strategy = StrategyRegistry.get(strategy_name) - ## TODO: This should be changed to instantiating Deepspeed and settting it in + # TODO: This should be changed to instantiating Deepspeed and setting it in # the constructor. This works for nowbecause forcing PyTorch optimizer flag isn't used # anywhere by Deepspeed. strategy.config["zero_force_ds_cpu_optimizer"] = False diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 2434c680..5486933f 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -4,10 +4,13 @@ import sys import warnings from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torchmetrics +from torch.nn import Parameter +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler from renate import defaults from renate.benchmark.scenarios import Scenario @@ -100,7 +103,27 @@ def get_loss_fn(config_module: ModuleType, convert: bool, **kwargs: Any) -> torc return loss_fn -def get_metrics(config_module: ModuleType) -> Dict[str, torchmetrics.Metric]: +def get_optimizer( + config_module: ModuleType, **kwargs: Any +) -> Optional[Callable[[List[Parameter]], Optimizer]]: + """Creates partial optimizer object from config.""" + optimizer_fn_name = "optimizer_fn" + if optimizer_fn_name in vars(config_module): + return getattr(config_module, optimizer_fn_name)(**kwargs) + + +def get_learning_rate_scheduler( + config_module: ModuleType, **kwargs: Any +) -> Optional[ + Tuple[Callable[[Optimizer], _LRScheduler], defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE] +]: + """Creates partial learning rate scheduler object from config.""" + lr_scheduler_fn_name = "lr_scheduler_fn" + if lr_scheduler_fn_name in vars(config_module): + return getattr(config_module, lr_scheduler_fn_name)(**kwargs) + + +def get_metrics(config_module: ModuleType) -> Optional[Dict[str, torchmetrics.Metric]]: """Creates and returns a dictionary of metrics.""" metrics_fn_name = "metrics_fn" if metrics_fn_name in vars(config_module): diff --git a/src/renate/utils/optimizer.py b/src/renate/utils/optimizer.py index ebac1e82..cd32e008 100644 --- a/src/renate/utils/optimizer.py +++ b/src/renate/utils/optimizer.py @@ -1,57 +1,33 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0 -from typing import List +from functools import partial +from typing import Callable, List import torch +from torch.nn import Parameter +from torch.optim import Optimizer import renate.defaults as defaults -def create_optimizer( - params: List[torch.nn.Parameter], +def create_partial_optimizer( optimizer: defaults.SUPPORTED_OPTIMIZERS_TYPE = defaults.OPTIMIZER, lr: float = defaults.LEARNING_RATE, momentum: float = defaults.MOMENTUM, weight_decay: float = defaults.WEIGHT_DECAY, -) -> torch.optim.Optimizer: - """Creates optimizer used to train the model. +) -> Callable[[List[Parameter]], Optimizer]: + """Creates a partial optimizer object. Args: - params: The list of parameters to be updated. - optimizer: The name of the optimizer to be used. Currently 'Adam' and 'SGD' are supported. + optimizer: The name of the optimizer to be used. Options: `Adam` or `SGD`. lr: Learning rate to be used. momentum: Value for the momentum hyperparameter (if relevant). weight_decay: Value for the weight_decay hyperparameter (if relevant). """ if optimizer == "SGD": - return torch.optim.SGD(params, lr, momentum, weight_decay) + return partial(torch.optim.SGD, lr=lr, momentum=momentum, weight_decay=weight_decay) elif optimizer == "Adam": - return torch.optim.Adam(params, lr, weight_decay=weight_decay) + return partial(torch.optim.Adam, lr=lr, weight_decay=weight_decay) else: raise ValueError(f"Unknown optimizer: {optimizer}.") - - -def create_scheduler( - optimizer: torch.optim.Optimizer, - scheduler: defaults.SUPPORTED_LEARNING_RATE_SCHEDULERS_TYPE = defaults.LEARNING_RATE_SCHEDULER, - step_size: int = defaults.LEARNING_RATE_SCHEDULER_STEP_SIZE, - gamma: float = defaults.LEARNING_RATE_SCHEDULER_GAMMA, -) -> torch.optim.lr_scheduler._LRScheduler: - """Creates a learning rate scheduler used to train the model. - - Args: - optimizer: The optimizer to be used. - scheduler: The name of the scheduler to be used. - step_size: Period of learning rate decay. - gamma: Value for the gamma hyperparameter (if relevant). 
- """ - - if scheduler == "ConstantLR": - return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0) - elif scheduler == "ExponentialLR": - return torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma) - elif scheduler == "StepLR": - return torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) - else: - raise ValueError(f"Unknown scheduler: {scheduler}.") diff --git a/test/conftest.py b/test/conftest.py index 659aa990..162a64ac 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import os import shutil -from typing import Callable, Dict +from typing import Callable, Dict, Literal import pytest import torch from pytorch_lightning.loggers import TensorBoardLogger +from renate import defaults from renate.benchmark.models import ( MultiLayerPerceptron, ResNet18, @@ -40,6 +41,7 @@ from renate.updaters.experimental.repeated_distill import RepeatedDistillationLearner from renate.updaters.learner import Learner, ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater +from renate.utils.optimizer import create_partial_optimizer pytest_plugins = ["helpers_namespace"] @@ -70,166 +72,57 @@ def pytest_collection_modifyitems(config, items): ExperienceReplayLearner: { "memory_size": 30, "memory_batch_size": 20, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - Learner: { - "optimizer": "SGD", - "learning_rate": 1.23, - "momentum": 0.9, - "weight_decay": 0.005, - "batch_size": 10, - "seed": 42, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - GDumbLearner: { - "optimizer": "SGD", - "learning_rate": 1.23, - "momentum": 0.9, - "weight_decay": 0.005, - "batch_size": 10, - "seed": 42, - "memory_size": 30, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - JointLearner: { - "optimizer": "SGD", - "learning_rate": 1.11, - "momentum": 0.4, - "weight_decay": 0.001, - "batch_size": 10, - "seed": 3, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - RepeatedDistillationLearner: { - "optimizer": "SGD", - "learning_rate": 1.23, - "momentum": 0.9, - "weight_decay": 0.005, - "batch_size": 10, - "seed": 42, - "memory_size": 30, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, + Learner: {"batch_size": 10, "seed": 42}, + GDumbLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, + JointLearner: {"batch_size": 10, "seed": 3}, + RepeatedDistillationLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, OfflineExperienceReplayLearner: { "memory_size": 30, "memory_batch_size": 20, "loss_weight_new_data": 0.5, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), }, } AVALANCHE_LEARNER_KWARGS = { AvalancheReplayLearner: { "memory_size": 30, "memory_batch_size": 20, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheEWCLearner: { "ewc_lambda": 0.1, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheLwFLearner: { "alpha": 0.1, "temperature": 2, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - 
"loss_fn": torch.nn.CrossEntropyLoss(), }, AvalancheICaRLLearner: { "memory_size": 30, "memory_batch_size": 20, - "optimizer": "SGD", - "learning_rate": 2.5, - "momentum": 1.3, - "weight_decay": 0.5, "batch_size": 50, "seed": 1, - "loss_fn": torch.nn.CrossEntropyLoss(), }, } LEARNER_HYPERPARAMETER_UPDATES = { - ExperienceReplayLearner: { - "optimizer": "Adam", - "learning_rate": 3.0, - "momentum": 0.5, - "weight_decay": 0.01, - "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - Learner: { - "optimizer": "Adam", - "learning_rate": 3.0, - "weight_decay": 0.01, - "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - GDumbLearner: { - "optimizer": "Adam", - "learning_rate": 2.0, - "momentum": 0.5, - "weight_decay": 0.03, - "batch_size": 128, - "memory_size": 50, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - JointLearner: { - "optimizer": "Adam", - "learning_rate": 2.0, - "weight_decay": 0.01, - "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - RepeatedDistillationLearner: { - "optimizer": "Adam", - "learning_rate": 2.0, - "weight_decay": 0.01, - "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, - OfflineExperienceReplayLearner: { - "optimizer": "Adam", - "learning_rate": 3.0, - "momentum": 0.5, - "weight_decay": 0.01, - "batch_size": 128, - "loss_fn": torch.nn.CrossEntropyLoss(reduction="none"), - }, + ExperienceReplayLearner: {"batch_size": 128}, + Learner: {"batch_size": 128}, + GDumbLearner: {"batch_size": 128, "memory_size": 50}, + JointLearner: {"batch_size": 128}, + RepeatedDistillationLearner: {"batch_size": 128}, + OfflineExperienceReplayLearner: {"batch_size": 128}, } AVALANCHE_LEARNER_HYPERPARAMETER_UPDATES = { - AvalancheEWCLearner: { - "ewc_lambda": 0.3, - }, - AvalancheLwFLearner: { - "alpha": 0.2, - "temperature": 3, - }, + AvalancheEWCLearner: {"ewc_lambda": 0.3}, + AvalancheLwFLearner: {"alpha": 0.2, "temperature": 3}, AvalancheICaRLLearner: {}, AvalancheReplayLearner: {}, } @@ -359,31 +252,6 @@ def get_renate_module_mlp_and_data( return model, train_dataset, test_data -@pytest.helpers.register -def get_renate_module_mlp_data_and_loss( - num_inputs, - num_outputs, - num_hidden_layers, - hidden_size, - train_num_samples, - test_num_samples, - val_num_samples=0, - add_icarl_class_means=False, -): - model, ds, test_data = get_renate_module_mlp_and_data( - num_inputs, - num_outputs, - num_hidden_layers, - hidden_size, - train_num_samples, - test_num_samples, - val_num_samples, - add_icarl_class_means, - ) - - return model, ds, test_data, get_loss_fn() - - @pytest.helpers.register def get_renate_vision_module_and_data( input_size, @@ -406,10 +274,11 @@ def get_renate_vision_module_and_data( @pytest.helpers.register def get_simple_updater( model, + partial_optimizer=None, input_state_folder=None, output_state_folder=None, learner_class=ExperienceReplayLearner, - learner_kwargs={"memory_size": 10, "loss_fn": pytest.helpers.get_loss_fn()}, + learner_kwargs=None, max_epochs=5, train_transform=None, train_target_transform=None, @@ -421,6 +290,8 @@ def get_simple_updater( metric=None, deterministic_trainer=False, ): + if learner_kwargs is None: + learner_kwargs = {"memory_size": 10} transforms_kwargs = { "train_transform": train_transform, "train_target_transform": train_target_transform, @@ -432,6 +303,8 @@ def get_simple_updater( transforms_kwargs["buffer_target_transform"] = buffer_target_transform return SingleTrainingLoopUpdater( 
model=model, + loss_fn=get_loss_fn(), + optimizer=partial_optimizer or get_partial_optimizer(), learner_class=learner_class, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, @@ -452,7 +325,7 @@ def get_avalanche_updater( input_state_folder=None, output_state_folder=None, learner_class=AvalancheReplayLearner, - learner_kwargs={"memory_size": 10, "loss_fn": torch.nn.CrossEntropyLoss()}, + learner_kwargs=None, max_epochs=5, train_transform=None, train_target_transform=None, @@ -461,6 +334,8 @@ def get_avalanche_updater( early_stopping_enabled=False, metric=None, ): + if learner_kwargs is None: + learner_kwargs = {"memory_size": 10} transforms_kwargs = { "train_transform": train_transform, "train_target_transform": train_target_transform, @@ -469,6 +344,8 @@ def get_avalanche_updater( } return AvalancheModelUpdater( model=model, + loss_fn=get_loss_fn("mean"), + optimizer=get_partial_optimizer(), learner_class=learner_class, learner_kwargs=learner_kwargs, input_state_folder=input_state_folder, @@ -481,6 +358,18 @@ def get_avalanche_updater( ) +@pytest.helpers.register +def get_partial_optimizer( + optimizer: Literal["Adam", "SGD"] = defaults.OPTIMIZER, + lr: float = defaults.LEARNING_RATE, + momentum: float = defaults.MOMENTUM, + weight_decay: float = defaults.WEIGHT_DECAY, +): + return create_partial_optimizer( + optimizer=optimizer, lr=lr, momentum=momentum, weight_decay=weight_decay + ) + + @pytest.helpers.register def check_learner_transforms(learner: Learner, expected_transforms: Dict[str, Callable]): """Checks if the learner transforms match to expected ones. diff --git a/test/renate/data/test_datasets.py b/test/renate/data/test_datasets.py index 95f96a83..351a8e71 100644 --- a/test/renate/data/test_datasets.py +++ b/test/renate/data/test_datasets.py @@ -10,7 +10,7 @@ from torch.utils.data import TensorDataset from renate.data import ImageDataset -from renate.data.datasets import _EnumeratedDataset, _TransformedDataset, IndexedSubsetDataset +from renate.data.datasets import IndexedSubsetDataset, _EnumeratedDataset, _TransformedDataset class MulTransform: diff --git a/test/renate/renate_config_files/config_custom_optimizer.py b/test/renate/renate_config_files/config_custom_optimizer.py new file mode 100644 index 00000000..cdea6820 --- /dev/null +++ b/test/renate/renate_config_files/config_custom_optimizer.py @@ -0,0 +1,43 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +from functools import partial +from typing import Callable, List, Optional, Tuple + +import torch +from torch.nn import Parameter +from torch.optim import Optimizer, SGD +from torch.optim.lr_scheduler import StepLR, _LRScheduler + +from dummy_datasets import DummyTorchVisionDataModule +from renate.benchmark.models.mlp import MultiLayerPerceptron +from renate.data.data_module import RenateDataModule +from renate.models import RenateModule + + +def model_fn(model_state_url: Optional[str] = None) -> RenateModule: + if model_state_url is None: + return MultiLayerPerceptron(5 * 5, 10, 0, 64) + state_dict = torch.load(model_state_url) + return MultiLayerPerceptron.from_state_dict(state_dict) + + +def data_module_fn( + data_path: str, + val_size: float = 0.0, + seed: int = 0, +) -> RenateDataModule: + return DummyTorchVisionDataModule(transform=None, val_size=val_size, seed=seed) + + +def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: + if updater.startswith("Avalanche-"): + return torch.nn.CrossEntropyLoss() + return torch.nn.CrossEntropyLoss(reduction="none") + + +def optimizer_fn() -> Callable[[List[Parameter]], Optimizer]: + return partial(SGD, lr=0.01) + + +def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]: + return partial(StepLR, step_size=10, gamma=0.1), "epoch" diff --git a/test/renate/shift/test_detectors.py b/test/renate/shift/test_detectors.py index 82ffd259..751c3c4f 100644 --- a/test/renate/shift/test_detectors.py +++ b/test/renate/shift/test_detectors.py @@ -3,8 +3,8 @@ import pytest import torch -from renate.shift.mmd_detectors import MMDCovariateShiftDetector from renate.shift.ks_detector import KolmogorovSmirnovCovariateShiftDetector +from renate.shift.mmd_detectors import MMDCovariateShiftDetector @pytest.mark.parametrize( diff --git a/test/renate/training/test_run_training.py b/test/renate/training/test_run_training.py index 399d15cd..082bd804 100644 --- a/test/renate/training/test_run_training.py +++ b/test/renate/training/test_run_training.py @@ -23,26 +23,31 @@ from renate.utils.syne_tune import is_syne_tune_config_space config_file = str(Path(__file__).parent.parent / "renate_config_files" / "config.py") +config_file_custom_optimizer = str( + Path(__file__).parent.parent / "renate_config_files" / "config_custom_optimizer.py" +) @pytest.mark.parametrize( - "num_chunks, val_size, raises, fixed_search_space, scheduler", + "num_chunks, val_size, raises, fixed_search_space, scheduler, configuration_file", [ - (2, 0.9, False, False, "rush"), - (1, 0.0, True, False, "rush"), - (2, 0.9, False, True, None), - (2, 0.0, False, True, None), + (2, 0.9, False, False, "rush", config_file), + (1, 0.0, True, False, "rush", config_file), + (2, 0.9, False, True, None, config_file), + (2, 0.0, False, True, None, config_file), + (1, 0.0, False, True, None, config_file_custom_optimizer), ], ids=[ "transfer-hpo-with-val", "transfer-hpo-without-val-raises-exception", "training-single-config-with-val", "training-single-config-without-val", + "training-single-config-without-val-custom-optimizer", ], ) @pytest.mark.parametrize("updater", ("ER", "Avalanche-iCaRL")) def test_run_training_job( - tmpdir, num_chunks, val_size, raises, fixed_search_space, scheduler, updater + tmpdir, num_chunks, val_size, raises, fixed_search_space, scheduler, updater, configuration_file ): """Simply running tuning job to check if anything fails. @@ -50,24 +55,28 @@ def test_run_training_job( Case 2: HPO without validation set fails. 
Case 3: Training of single configuration with validation set. Case 4: Training of single configuration without validation set. + Case 5: Training of single configuration without validation set using custom optimizer and + learning rate scheduler.. """ state_url = None tmpdir = str(tmpdir) for _ in range(num_chunks): def execute_job(): + config_space = {"val_size": val_size} + if configuration_file == config_file: + config_space["learning_rate"] = ( + 0.1 if fixed_search_space else loguniform(10e-5, 0.1) + ) run_training_job( updater=updater, max_epochs=5, - config_file=config_file, + config_file=configuration_file, input_state_url=state_url, output_state_url=tmpdir, backend="local", mode="max", - config_space={ - "learning_rate": 0.1 if fixed_search_space else loguniform(10e-5, 0.1), - "val_size": val_size, - }, + config_space=config_space, metric="val_accuracy", max_time=30, scheduler=scheduler, diff --git a/test/renate/updaters/avalanche/test_avalanche_learner.py b/test/renate/updaters/avalanche/test_avalanche_learner.py index ae6aaebc..6e0e3eda 100644 --- a/test/renate/updaters/avalanche/test_avalanche_learner.py +++ b/test/renate/updaters/avalanche/test_avalanche_learner.py @@ -80,10 +80,16 @@ def test_update_settings(learner_class): ) plugins = [] expected_max_epochs = 10 + expected_loss_fn = pytest.helpers.get_loss_fn("mean") expected_optimizer = SGD(expected_model.parameters(), lr=0.1) expected_device = torch.device("cpu") expected_eval_every = -1 - learner = learner_class(model=expected_model, **learner_kwargs) + learner = learner_class( + model=expected_model, + optimizer=None, + loss_fn=expected_loss_fn, + **learner_kwargs, + ) avalanche_learner = learner.create_avalanche_learner( plugins=plugins, optimizer=expected_optimizer, @@ -96,6 +102,7 @@ def test_update_settings(learner_class): learner_kwargs, avalanche_learner, expected_model=expected_model, + expected_loss_fn=expected_loss_fn, expected_optimizer=expected_optimizer, expected_max_epochs=expected_max_epochs, expected_device=expected_device, @@ -123,6 +130,7 @@ def test_update_settings(learner_class): learner_kwargs, avalanche_learner, expected_model=expected_model, + expected_loss_fn=expected_loss_fn, expected_optimizer=expected_optimizer, expected_max_epochs=expected_max_epochs, expected_device=expected_device, diff --git a/test/renate/updaters/avalanche/test_avalanche_model_updater.py b/test/renate/updaters/avalanche/test_avalanche_model_updater.py index 000d4bda..2738c733 100644 --- a/test/renate/updaters/avalanche/test_avalanche_model_updater.py +++ b/test/renate/updaters/avalanche/test_avalanche_model_updater.py @@ -87,11 +87,12 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b "memory_size": memory_size, "memory_batch_size": memory_batch_size, "batch_size": batch_size, - "loss_fn": pytest.helpers.get_loss_fn("mean"), } model_updater = ExperienceReplayAvalancheModelUpdater( - output_state_folder=Path(tmpdir) / "0", + output_state_folder=str(Path(tmpdir) / "0"), model=model, + loss_fn=pytest.helpers.get_loss_fn("mean"), + optimizer=pytest.helpers.get_partial_optimizer(), **learner_kwargs, max_epochs=1, accelerator="cpu", @@ -110,9 +111,11 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b del replay_plugin _, dataset = get_model_and_dataset(dataset_size) model_updater = ExperienceReplayAvalancheModelUpdater( - input_state_folder=Path(tmpdir) / str(i), - output_state_folder=Path(tmpdir) / str(i + 1), + input_state_folder=str(Path(tmpdir) / str(i)), + 
output_state_folder=str(Path(tmpdir) / str(i + 1)), model=model, + loss_fn=pytest.helpers.get_loss_fn("mean"), + optimizer=pytest.helpers.get_partial_optimizer(), **learner_kwargs, max_epochs=1, ) diff --git a/test/renate/updaters/experimental/test_er.py b/test/renate/updaters/experimental/test_er.py index ae89346e..1459f0a4 100644 --- a/test/renate/updaters/experimental/test_er.py +++ b/test/renate/updaters/experimental/test_er.py @@ -36,10 +36,10 @@ def test_er_overall_memory_size_after_update(batch_size, memory_size, memory_bat "memory_size": memory_size, "memory_batch_size": memory_batch_size, "batch_size": batch_size, - "loss_fn": pytest.helpers.get_loss_fn(), } model_updater = pytest.helpers.get_simple_updater( model=model, + partial_optimizer=pytest.helpers.get_partial_optimizer(), learner_class=ExperienceReplayLearner, learner_kwargs=learner_kwargs, max_epochs=1, @@ -94,16 +94,11 @@ def test_er_validation_buffer(tmpdir): [ [ ExperienceReplayLearner, - { - "alpha": 0.2, - "memory_size": 10, - "memory_batch_size": 10, - "loss_fn": pytest.helpers.get_loss_fn(), - }, + {"alpha": 0.2, "memory_size": 10, "memory_batch_size": 10}, ], [ DarkExperienceReplayLearner, - {"alpha": 0.1, "beta": 0.3, "memory_size": 42, "loss_fn": pytest.helpers.get_loss_fn()}, + {"alpha": 0.1, "beta": 0.3, "memory_size": 42}, ], [ CLSExperienceReplayLearner, @@ -115,18 +110,11 @@ def test_er_validation_buffer(tmpdir): "stable_model_update_probability": 0.3, "plastic_model_update_probability": 0.5, "memory_size": 42, - "loss_fn": pytest.helpers.get_loss_fn(), }, ], [ PooledOutputDistillationExperienceReplayLearner, - { - "alpha": 0.3, - "distillation_type": "pixel", - "normalize": False, - "memory_size": 42, - "loss_fn": pytest.helpers.get_loss_fn(), - }, + {"alpha": 0.3, "distillation_type": "pixel", "normalize": False, "memory_size": 42}, ], [ SuperExperienceReplayLearner, @@ -144,7 +132,6 @@ def test_er_validation_buffer(tmpdir): "pod_distillation_type": "pixel", "pod_normalize": False, "memory_size": 42, - "loss_fn": pytest.helpers.get_loss_fn(), }, ], ], @@ -155,9 +142,19 @@ def test_er_components_save_and_load(tmpdir, cls, kwargs): model = pytest.helpers.get_renate_module_mlp( num_inputs=10, num_outputs=10, hidden_size=32, num_hidden_layers=3 ) - learner = cls(model=model, **kwargs) + learner = cls( + model=model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + **kwargs, + ) torch.save(learner.state_dict(), os.path.join(tmpdir, "learner.pt")) - learner = cls(model=model, **kwargs) + learner = cls( + model=model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + **kwargs, + ) learner.load_state_dict(torch.load(os.path.join(tmpdir, "learner.pt"))) if isinstance(learner, ExperienceReplayLearner) and not isinstance( learner, DarkExperienceReplayLearner diff --git a/test/renate/updaters/experimental/test_fine_tuning.py b/test/renate/updaters/experimental/test_fine_tuning.py index 18098686..292cdb1e 100644 --- a/test/renate/updaters/experimental/test_fine_tuning.py +++ b/test/renate/updaters/experimental/test_fine_tuning.py @@ -31,6 +31,7 @@ def test_fine_tuning_updater(devices, strategy, accelerator): model_updater = FineTuningModelUpdater( model, + optimizer=pytest.helpers.get_partial_optimizer(), loss_fn=pytest.helpers.get_loss_fn(), max_epochs=1, devices=devices, diff --git a/test/renate/updaters/experimental/test_joint.py b/test/renate/updaters/experimental/test_joint.py index c31e69c9..fb0e62e0 100644 --- 
a/test/renate/updaters/experimental/test_joint.py +++ b/test/renate/updaters/experimental/test_joint.py @@ -26,8 +26,9 @@ def test_joint_learner_memory_append(): dataset_len = len(dataset) model_updater = pytest.helpers.get_simple_updater( model=model, + partial_optimizer=pytest.helpers.get_partial_optimizer(), learner_class=JointLearner, - learner_kwargs={"loss_fn": pytest.helpers.get_loss_fn()}, + learner_kwargs={}, max_epochs=1, ) model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) @@ -41,11 +42,9 @@ def test_joint_learner_model_reset(): model, dataset = get_model_and_dataset() model_updater = pytest.helpers.get_simple_updater( model=model, + partial_optimizer=pytest.helpers.get_partial_optimizer(lr=0.0), learner_class=JointLearner, - learner_kwargs={ - "learning_rate": 0.0, - "loss_fn": pytest.helpers.get_loss_fn(), - }, + learner_kwargs={}, max_epochs=1, ) model = model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) diff --git a/test/renate/updaters/experimental/test_repeated_distill.py b/test/renate/updaters/experimental/test_repeated_distill.py index 4d48d0cc..3565bf29 100644 --- a/test/renate/updaters/experimental/test_repeated_distill.py +++ b/test/renate/updaters/experimental/test_repeated_distill.py @@ -28,9 +28,14 @@ def test_dmc_runs_end_to_end(): data.append(ds) val = ConcatDataset(data) - loss_fn = pytest.helpers.get_loss_fn() updater = RepeatedDistillationModelUpdater( - model=mlp, memory_size=300, batch_size=20, max_epochs=5, loss_fn=loss_fn, accelerator="cpu" + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + model=mlp, + memory_size=300, + batch_size=20, + max_epochs=5, + accelerator="cpu", ) for i in range(len(data)): @@ -43,9 +48,13 @@ def test_dmc_memory_size_after_update(memory_size, dataset_size): model = pytest.helpers.get_renate_module_mlp( num_inputs=10, num_outputs=3, hidden_size=20, num_hidden_layers=1 ) - loss_fn = pytest.helpers.get_loss_fn() model_updater = RepeatedDistillationModelUpdater( - model=model, memory_size=memory_size, max_epochs=1, loss_fn=loss_fn, accelerator="cpu" + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + model=model, + memory_size=memory_size, + max_epochs=1, + accelerator="cpu", ) datasets = [ TensorDataset( @@ -63,7 +72,7 @@ def test_dmc_memory_size_after_update(memory_size, dataset_size): @pytest.mark.parametrize("provide_folder", [True, False]) def test_dmc_model_updater(tmpdir, provide_folder): - model, train_dataset, test_data, loss_fn = pytest.helpers.get_renate_module_mlp_data_and_loss( + model, train_dataset, test_data = pytest.helpers.get_renate_module_mlp_and_data( num_inputs=10, num_outputs=10, hidden_size=32, @@ -73,7 +82,8 @@ def test_dmc_model_updater(tmpdir, provide_folder): ) model_updater = RepeatedDistillationModelUpdater( model, - loss_fn=loss_fn, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), memory_size=50, max_epochs=1, output_state_folder=defaults.output_state_folder(tmpdir) if provide_folder else None, @@ -87,7 +97,7 @@ def test_dmc_model_updater(tmpdir, provide_folder): def test_continuation_of_training_with_dmc_model_updater(tmpdir): - model, train_dataset, _, loss_fn = pytest.helpers.get_renate_module_mlp_data_and_loss( + model, train_dataset, _ = pytest.helpers.get_renate_module_mlp_and_data( num_inputs=10, num_outputs=10, hidden_size=32, @@ -98,19 +108,21 @@ def test_continuation_of_training_with_dmc_model_updater(tmpdir): state_url = 
defaults.input_state_folder(tmpdir) model_updater = RepeatedDistillationModelUpdater( model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), memory_size=50, max_epochs=1, output_state_folder=state_url, - loss_fn=loss_fn, accelerator="cpu", ) model = model_updater.update(train_dataset, task_id=defaults.TASK_ID) model_updater = RepeatedDistillationModelUpdater( model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), memory_size=50, max_epochs=1, input_state_folder=state_url, - loss_fn=loss_fn, accelerator="cpu", ) model_updater.update(train_dataset, task_id=defaults.TASK_ID) diff --git a/test/renate/updaters/test_learner.py b/test/renate/updaters/test_learner.py index e95cffcd..0273702e 100644 --- a/test/renate/updaters/test_learner.py +++ b/test/renate/updaters/test_learner.py @@ -3,8 +3,8 @@ from typing import Any, Dict, Tuple, Type import pytest -from conftest import LEARNER_KWARGS, LEARNERS +from conftest import LEARNERS, LEARNER_KWARGS from renate.models import RenateModule from renate.updaters.learner import Learner @@ -16,7 +16,12 @@ def get_model_and_learner_and_learner_kwargs( model = pytest.helpers.get_renate_module_mlp( num_inputs=1, num_outputs=1, hidden_size=1, num_hidden_layers=1 ) - learner = learner_class(model=model, **learner_kwargs) + learner = learner_class( + model=model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + **learner_kwargs, + ) return model, learner, learner_kwargs @@ -33,7 +38,7 @@ def check_learner_variables(learner: Learner, expected_variable_values: Dict[str @pytest.mark.parametrize("learner_class", LEARNERS) -def test_save_and_load_learner(tmpdir, learner_class): +def test_save_and_load_learner(learner_class): model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) checkpoint_dict = {} learner.on_save_checkpoint(checkpoint=checkpoint_dict) diff --git a/test/renate/updaters/test_model_updater.py b/test/renate/updaters/test_model_updater.py index 66b1dfc5..7e541f91 100644 --- a/test/renate/updaters/test_model_updater.py +++ b/test/renate/updaters/test_model_updater.py @@ -73,7 +73,7 @@ def test_deterministic_updater(): def test_model_updater_with_early_stopping( use_val, early_stopping_enabled, metric_monitored, updater_type ): - model, train_dataset, val_dataset, loss = pytest.helpers.get_renate_module_mlp_data_and_loss( + model, train_dataset, val_dataset = pytest.helpers.get_renate_module_mlp_and_data( num_inputs=10, num_outputs=10, hidden_size=8, @@ -88,7 +88,8 @@ def test_model_updater_with_early_stopping( if updater_type == "DMC": model_updater = RepeatedDistillationModelUpdater( model=model, - loss_fn=loss, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), memory_size=50, max_epochs=max_epochs, early_stopping_enabled=early_stopping_enabled, diff --git a/test/renate/utils/test_distributed_strategies.py b/test/renate/utils/test_distributed_strategies.py index 716e19c9..34610ef2 100644 --- a/test/renate/utils/test_distributed_strategies.py +++ b/test/renate/utils/test_distributed_strategies.py @@ -3,9 +3,9 @@ import pytest from renate.utils.distributed_strategies import ( - create_strategy, _SUPPORTED_STRATEGIES, _UNSUPPORTED_STRATEGIES, + create_strategy, ) diff --git a/test/renate/utils/test_optimizer.py b/test/renate/utils/test_optimizer.py index a3de7adc..666c45c9 100644 --- a/test/renate/utils/test_optimizer.py +++ 
b/test/renate/utils/test_optimizer.py @@ -4,7 +4,7 @@ import torch from renate.defaults import TASK_ID -from renate.utils.optimizer import create_optimizer, create_scheduler +from renate.utils.optimizer import create_partial_optimizer @pytest.mark.parametrize( @@ -20,47 +20,11 @@ def test_make_optimizer_with_different_configurations(optimizer, kwargs): ) params = model.get_params(task_id=TASK_ID) - opt = create_optimizer(params, optimizer=optimizer, **kwargs) + opt = create_partial_optimizer(optimizer=optimizer, **kwargs)(params) assert isinstance(opt, torch.optim.Optimizer) def test_unknown_optimizer_raises_error(): - model = pytest.helpers.get_renate_module_mlp( - num_outputs=10, num_inputs=10, hidden_size=32, num_hidden_layers=3 - ) - params = model.get_params(task_id=TASK_ID) - - with pytest.raises(ValueError): - create_optimizer( - params, optimizer="UNKNOWN_OPTIMIZER", lr=0.01, momentum=0.0, weight_decay=0.5 - ) - - -@pytest.mark.parametrize( - "scheduler,kwargs", - [ - ("ConstantLR", {}), - ("ExponentialLR", {"gamma": 0.5}), - ("StepLR", {"step_size": 10, "gamma": 0.5}), - ], -) -def test_make_scheduler_with_different_configurations(scheduler, kwargs): - model = pytest.helpers.get_renate_module_mlp( - num_outputs=10, num_inputs=10, hidden_size=32, num_hidden_layers=3 - ) - params = model.get_params(task_id=TASK_ID) - opt = create_optimizer(params, optimizer="SGD", lr=0.01, momentum=0.0, weight_decay=0.5) - - sch = create_scheduler(opt, scheduler=scheduler, **kwargs) - assert isinstance(sch, torch.optim.lr_scheduler._LRScheduler) - - -def test_unknown_scheduler_raises_error(): - model = pytest.helpers.get_renate_module_mlp( - num_outputs=10, num_inputs=10, hidden_size=32, num_hidden_layers=3 - ) - params = model.get_params(task_id=TASK_ID) - opt = create_optimizer(params, optimizer="SGD", lr=0.01, momentum=0.0, weight_decay=0.5) - - with pytest.raises(ValueError): - create_scheduler(opt, scheduler="UNKNOWN_SCHEDULER", gamma=0.5) + optimizer_name = "Unknown Optimizer" + with pytest.raises(ValueError, match=f"Unknown optimizer: {optimizer_name}."): + create_partial_optimizer(optimizer=optimizer_name, lr=0.01, momentum=0.0, weight_decay=0.5) From 8023952d99365145565047f7ccd25468565cc58d Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Fri, 16 Jun 2023 08:29:10 +0200 Subject: [PATCH 38/89] Flag to remove intermediate tasks' states (#289) --- src/renate/benchmark/experimentation.py | 16 +++++++++++++++- src/renate/defaults.py | 1 + 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 213a8527..a4d288ff 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -149,6 +149,7 @@ def execute_experiment_job( job_name: str = defaults.JOB_NAME, strategy: str = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, + retain_intermediate_state: bool = defaults.RETAIN_INTERMEDIATE_STATE, ) -> None: """Executes the experiment job. @@ -179,6 +180,11 @@ def execute_experiment_job( deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. In this function this parameter is set to True by default. job_name: Name of the experiment job. + strategy: String denoting lightning distributed strategy. + precision: String for which precision to use. + retain_intermediate_state: Flag to retain models and buffer states after each + task update. This is useful when training with large datasets that might cause storage + issues. 
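For illustration, a local experiment that keeps only the final state could be launched roughly
as follows; this is a sketch, ``config_file`` and ``experiment_outputs_url`` are assumed
argument names, and all concrete values are placeholders::

    from renate.benchmark.experimentation import execute_experiment_job

    execute_experiment_job(
        backend="local",
        config_file="experiment_config.py",
        experiment_outputs_url="renate_experiment/",
        config_space={"learning_rate": 0.05, "val_size": 0.2},
        mode="max",
        metric="val_accuracy",
        num_updates=5,
        # Drop per-update model/buffer snapshots to save storage.
        retain_intermediate_state=False,
    )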
""" assert ( mode in defaults.SUPPORTED_TUNING_MODE @@ -206,6 +212,7 @@ def execute_experiment_job( seed=seed, strategy=strategy, precision=precision, + retain_intermediate_state=retain_intermediate_state, ) _execute_experiment_job_remotely( job_name=job_name, @@ -232,6 +239,7 @@ def execute_experiment_job( instance_max_time=instance_max_time, strategy=strategy, precision=precision, + retain_intermediate_state=retain_intermediate_state, ) @@ -254,6 +262,7 @@ def _execute_experiment_job_locally( deterministic_trainer: bool, strategy: str, precision: str, + retain_intermediate_state: bool, ) -> None: """Runs an experiment, combining hyperparameter tuning and model for multiple updates. @@ -347,7 +356,8 @@ def _execute_experiment_job_locally( deterministic_trainer=deterministic_trainer, ) move_to_uri(output_state_url, input_state_url) - copy_to_uri(input_state_url, update_url) + if retain_intermediate_state: + copy_to_uri(input_state_url, update_url) model = get_model( config_module, model_state_url=model_url, @@ -374,6 +384,10 @@ def _execute_experiment_job_locally( cumulative_metrics = create_cumulative_metrics("classification") df = cumulative_metrics_summary(results, cumulative_metrics, num_updates - 1) save_pandas_df_to_csv(df, defaults.metric_summary_file(logs_url)) + if not retain_intermediate_state: + move_to_uri( + defaults.hpo_file(input_state_url), defaults.logs_folder(experiment_outputs_url) + ) logger.info("### Cumulative results: ###") logger.info(df) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index c2852964..5c6208f6 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -55,6 +55,7 @@ JOB_NAME = "renate" SUPPORTED_TUNING_MODE = ["min", "max"] SUPPORTED_TUNING_MODE_TYPE = Literal["min", "max"] +RETAIN_INTERMEDIATE_STATE = True SUPPORTED_BACKEND = ["local", "sagemaker"] SUPPORTED_BACKEND_TYPE = Literal["local", "sagemaker"] From f542c34770b898fff8170c24077f7ca21eb635a4 Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 16 Jun 2023 17:01:34 +0200 Subject: [PATCH 39/89] Missing Argument Doesn't Allow for Remote Experiments (#304) --- src/renate/benchmark/experimentation.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index a4d288ff..f8974d14 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -133,6 +133,7 @@ def execute_experiment_job( num_updates: int, working_directory: Optional[str] = defaults.WORKING_DIRECTORY, requirements_file: Optional[str] = None, + dependencies: Optional[List[str]] = None, role: Optional[str] = None, instance_type: str = defaults.INSTANCE_TYPE, instance_count: int = defaults.INSTANCE_COUNT, @@ -165,6 +166,8 @@ def execute_experiment_job( num_updates: Number of updates of the experiment job. working_directory: Path to the working directory. requirements_file: Path to the requirements file. + dependencies: (SageMaker backend only) List of strings containing absolute or relative paths + to files and directories that will be uploaded as part of the SageMaker training job. role: Role of the experiment job. instance_type: Instance type of the experiment job. instance_count: Instance count of the experiment job. 
@@ -222,6 +225,7 @@ def execute_experiment_job( metric=metric, num_updates=num_updates, working_directory=working_directory, + dependencies=dependencies or [], config_space=config_space, max_time=max_time, max_num_trials_started=max_num_trials_started, From e4984aae4cbfdf55b2a87588376b39008297f809 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Mon, 19 Jun 2023 13:15:57 +0200 Subject: [PATCH 40/89] Using HuggingFace ViT implementation (#303) --- .../benchmark/models/vision_transformer.py | 132 +++++++++--------- .../models/test_vision_transformer.py | 16 ++- 2 files changed, 77 insertions(+), 71 deletions(-) diff --git a/src/renate/benchmark/models/vision_transformer.py b/src/renate/benchmark/models/vision_transformer.py index 6b16fc53..4ad4aa19 100644 --- a/src/renate/benchmark/models/vision_transformer.py +++ b/src/renate/benchmark/models/vision_transformer.py @@ -1,16 +1,46 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from functools import partial -from typing import Any, Callable, List, Optional +from typing import Any, Optional, Tuple, Union -import torch.nn as nn -from torchvision.models.vision_transformer import ConvStemConfig, WeightsEnum -from torchvision.models.vision_transformer import VisionTransformer as _VisionTransformer +import torch +from transformers import ViTConfig, ViTModel +from transformers.modeling_outputs import BaseModelOutputWithPooling from renate.benchmark.models.base import RenateBenchmarkingModule from renate.models.prediction_strategies import PredictionStrategy +class FeatureExtractorViTModel(ViTModel): + """This class directly outputs [CLS] features directly""" + + def forward( + self, + pixel_values: Optional[torch.Tensor] = None, + bool_masked_pos: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + interpolate_pos_encoding: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + """Output has patch embeddings and the pooled output. We extract pooled CLS out by + taking the second element. + """ + out_to_filter = super().forward( + pixel_values, + bool_masked_pos, + head_mask, + output_attentions, + output_hidden_states, + interpolate_pos_encoding, + return_dict, + ) + + if isinstance(out_to_filter, BaseModelOutputWithPooling): + return out_to_filter.pooler_output + return out_to_filter[1] + + class VisionTransformer(RenateBenchmarkingModule): """Vision Transformer base model. @@ -20,6 +50,8 @@ class VisionTransformer(RenateBenchmarkingModule): arXiv preprint arXiv:2010.11929 (2020). Args: + pretrained_name: A string that denotes which pretrained model from the HF hub to use. + If provided, it overrides other arguments about architecture. image_size: Size of the input image. patch_size: Size of the patches. num_layers: Number of Encoder layers. @@ -29,11 +61,6 @@ class VisionTransformer(RenateBenchmarkingModule): dropout: Dropout probability. attention_dropout: Dropout probability for the attention in the Multi-head Attention layer. num_outputs: Size of the output. - representation_size: If specified, the model will return a linear projection of the last - hidden state. - norm_layer: Normalization layer. - conv_stem_configs: List of ConvStemConfig. Each ConvStemConfig corresponds to a - convolutional stem. prediction_strategy: Continual learning strategies may alter the prediction at train or test time. 
add_icarl_class_means: If ``True``, additional parameters used only by the @@ -42,6 +69,7 @@ class VisionTransformer(RenateBenchmarkingModule): def __init__( self, + pretrained_model_name_or_path: Optional[str] = None, image_size: int = 32, patch_size: int = 4, num_layers: int = 12, @@ -51,29 +79,34 @@ def __init__( dropout: float = 0.1, attention_dropout: float = 0.1, num_outputs: int = 10, - representation_size: Optional[int] = None, - norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6), - conv_stem_configs: Optional[List[ConvStemConfig]] = None, - weights: Optional[WeightsEnum] = None, prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, ) -> None: - model = _VisionTransformer( - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - dropout=dropout, - attention_dropout=attention_dropout, - num_classes=num_outputs, - representation_size=representation_size, - norm_layer=norm_layer, - conv_stem_configs=conv_stem_configs, - ) + if pretrained_model_name_or_path: + model = FeatureExtractorViTModel.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, return_dict=False + ) + else: + model_config = ViTConfig( + hidden_size=hidden_dim, + num_hidden_layers=num_layers, + num_attention_heads=num_heads, + intermediate_size=mlp_dim, + hidden_act="gelu", + hidden_dropout_prob=dropout, + attention_probs_dropout_prob=attention_dropout, + layer_norm_eps=1e-6, + image_size=image_size, + patch_size=patch_size, + num_channels=3, + qkv_bias=True, + return_dict=False, + ) + + model = FeatureExtractorViTModel(config=model_config) + super().__init__( - embedding_size=model.heads.head.in_features, + embedding_size=hidden_dim, num_outputs=num_outputs, constructor_arguments={ "image_size": image_size, @@ -84,17 +117,11 @@ def __init__( "mlp_dim": mlp_dim, "dropout": dropout, "attention_dropout": attention_dropout, - "representation_size": representation_size, - "norm_layer": norm_layer, - "conv_stem_configs": conv_stem_configs, }, prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) self._backbone = model - if weights: - self._backbone.load_state_dict(weights.get_state_dict()) - self._backbone.heads.head = nn.Identity() class VisionTransformerCIFAR(VisionTransformer): @@ -113,12 +140,7 @@ def __init__(self, **kwargs: Any) -> None: class VisionTransformerB16(VisionTransformer): def __init__(self, **kwargs: Any) -> None: super().__init__( - image_size=224, - patch_size=16, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, + pretrained_model_name_or_path="google/vit-base-patch16-224", **kwargs, ) @@ -126,12 +148,7 @@ def __init__(self, **kwargs: Any) -> None: class VisionTransformerB32(VisionTransformer): def __init__(self, **kwargs: Any) -> None: super().__init__( - image_size=224, - patch_size=32, - num_layers=12, - num_heads=12, - hidden_dim=768, - mlp_dim=3072, + pretrained_model_name_or_path="google/vit-base-patch32-224-in21k", **kwargs, ) @@ -139,12 +156,7 @@ def __init__(self, **kwargs: Any) -> None: class VisionTransformerL16(VisionTransformer): def __init__(self, **kwargs: Any) -> None: super().__init__( - image_size=224, - patch_size=16, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, + pretrained_model_name_or_path="google/vit-large-patch16-224-in21k", **kwargs, ) @@ -152,12 +164,7 @@ def __init__(self, **kwargs: Any) -> None: class 
VisionTransformerL32(VisionTransformer): def __init__(self, **kwargs: Any) -> None: super().__init__( - image_size=224, - patch_size=32, - num_layers=24, - num_heads=16, - hidden_dim=1024, - mlp_dim=4096, + pretrained_model_name_or_path="google/vit-large-patch32-224-in21k", **kwargs, ) @@ -165,11 +172,6 @@ def __init__(self, **kwargs: Any) -> None: class VisionTransformerH14(VisionTransformer): def __init__(self, **kwargs: Any) -> None: super().__init__( - image_size=224, - patch_size=14, - num_layers=32, - num_heads=16, - hidden_dim=1280, - mlp_dim=5120, + pretrained_model_name_or_path="google/vit-huge-patch14-224-in21k", **kwargs, ) diff --git a/test/renate/benchmark/models/test_vision_transformer.py b/test/renate/benchmark/models/test_vision_transformer.py index cbbdfea5..8ebe19c6 100644 --- a/test/renate/benchmark/models/test_vision_transformer.py +++ b/test/renate/benchmark/models/test_vision_transformer.py @@ -37,15 +37,19 @@ def test_renate_vision_transformer_fwd(sub_class, input_dim): assert y_hat.shape[1] == 10 +# The following numbers have been computed by +# for m in [VisionTransformerB16, VisionTransformerB32, VisionTransformerCIFAR, +# VisionTransformerH14, VisionTransformerL16, VisionTransformerL32]: +# print(len(list(m()._backbone.parameters()))) @pytest.mark.parametrize( "sub_class, expected_num_params", [ - ["visiontransformercifar", 42], - ["visiontransformerb16", 150], - ["visiontransformerb32", 150], - ["visiontransformerl16", 294], - ["visiontransformerl32", 294], - ["visiontransformerh14", 390], + ["visiontransformercifar", 56], + ["visiontransformerb16", 200], + ["visiontransformerb32", 200], + ["visiontransformerl16", 392], + ["visiontransformerl32", 392], + ["visiontransformerh14", 520], ], ) def test_renate_vision_transformer_get_params(sub_class, expected_num_params): From c5265101a894f5df24cd9b6e471dc9f35ae0f870 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 21 Jun 2023 15:44:39 +0200 Subject: [PATCH 41/89] Introduce `RenateLightningModule` (#301) --- examples/getting_started/renate_config.py | 2 +- examples/nlp_finetuning/start.py | 4 +- .../renate_config.py | 9 +- examples/train_mlp_locally/renate_config.py | 9 +- src/renate/benchmark/experiment_config.py | 11 +- src/renate/benchmark/experimentation.py | 28 +- src/renate/benchmark/scenarios.py | 52 ++-- src/renate/cli/run_training.py | 2 + src/renate/data/data_module.py | 17 +- src/renate/defaults.py | 1 - src/renate/evaluation/evaluator.py | 31 +- src/renate/evaluation/metrics/utils.py | 35 --- src/renate/training/training.py | 4 +- .../updaters/avalanche/model_updater.py | 20 +- src/renate/updaters/experimental/er.py | 15 +- src/renate/updaters/experimental/gdumb.py | 15 +- src/renate/updaters/experimental/joint.py | 15 +- .../updaters/experimental/offline_er.py | 15 +- .../updaters/experimental/repeated_distill.py | 37 ++- src/renate/updaters/learner.py | 278 +++++++++++++----- src/renate/updaters/model_updater.py | 20 +- src/renate/utils/avalanche.py | 14 +- src/renate/utils/module.py | 4 + .../benchmark/test_experimentation_config.py | 5 + test/renate/renate_config_files/config.py | 7 +- .../config_custom_optimizer.py | 7 +- .../renate_config_files/config_scenario.py | 14 +- test/renate/training/test_run_training.py | 2 +- test/renate/updaters/test_model_updater.py | 4 +- test/renate/utils/test_metrics_utils.py | 64 ---- 30 files changed, 472 insertions(+), 269 deletions(-) delete mode 100644 src/renate/evaluation/metrics/utils.py delete mode 100644 test/renate/utils/test_metrics_utils.py diff --git 
a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index 5045cb9d..9a1e883b 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -93,7 +93,7 @@ def buffer_transform() -> Callable: def metrics_fn() -> Dict: - return {"my_accuracy": Accuracy()} + return {"accuracy": Accuracy()} def loss_fn() -> torch.nn.Module: diff --git a/examples/nlp_finetuning/start.py b/examples/nlp_finetuning/start.py index 40ae268f..cb6a8951 100644 --- a/examples/nlp_finetuning/start.py +++ b/examples/nlp_finetuning/start.py @@ -25,8 +25,8 @@ run_training_job( config_space=config_space, - mode="max", - metric="val_accuracy", + mode="min", + metric="val_loss", updater="ER", # we train with Experience Replay max_epochs=5, config_file="renate_config.py", diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 49567c9d..4cf5883a 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional +from typing import Callable, Dict, Optional import torch +from torchmetrics import Accuracy from torchvision import transforms import renate.defaults as defaults @@ -36,7 +37,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> ) class_incremental_scenario = ClassIncrementalScenario( data_module=data_module, - class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), chunk_id=chunk_id, ) return class_incremental_scenario @@ -65,3 +66,7 @@ def buffer_transform() -> Callable: def loss_fn() -> torch.nn.Module: return torch.nn.CrossEntropyLoss(reduction="none") + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index 772577d3..58b5f42a 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional +from typing import Callable, Dict, Optional import torch +from torchmetrics import Accuracy from torchvision.transforms import transforms from renate import defaults @@ -27,7 +28,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> class_incremental_scenario = ClassIncrementalScenario( data_module=data_module, - class_groupings=[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], + class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), chunk_id=chunk_id, ) return class_incremental_scenario @@ -52,3 +53,7 @@ def train_transform() -> Callable: def loss_fn() -> torch.nn.Module: return torch.nn.CrossEntropyLoss(reduction="none") + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 4cb3ca4d..d5009af4 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -1,9 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch import wild_time_data +from torchmetrics import Accuracy from torchvision.transforms import transforms from transformers import AutoTokenizer @@ -291,7 +292,7 @@ def _get_normalize_transform(dataset_name): ) -def train_transform(dataset_name: str) -> Optional[transforms.Compose]: +def train_transform(dataset_name: str) -> Optional[Callable]: """Returns a transform function to be used in the training.""" if dataset_name in [ "MNIST", @@ -317,7 +318,7 @@ def train_transform(dataset_name: str) -> Optional[transforms.Compose]: raise ValueError(f"Unknown dataset `{dataset_name}`.") -def test_transform(dataset_name: str) -> Optional[transforms.Normalize]: +def test_transform(dataset_name: str) -> Optional[Callable]: """Returns a transform function to be used for validation or testing.""" if dataset_name in [ "MNIST", @@ -335,3 +336,7 @@ def test_transform(dataset_name: str) -> Optional[transforms.Normalize]: ] ) raise ValueError(f"Unknown dataset `{dataset_name}`.") + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index f8974d14..474aef08 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -47,22 +47,16 @@ def experiment_config_file(): return str(Path(renate.__path__[0]) / "benchmark" / "experiment_config.py") -def create_cumulative_metrics(task: defaults.SUPPORTED_TASKS_TYPE) -> List[Tuple[str, Callable]]: +def create_cumulative_metrics() -> List[Tuple[str, Callable]]: """Gets the cumulative metrics for a given task along with a name of the metric to include in any potential results table. - - Args: - task: Whether classification or regression, for now. """ - if task == "classification": - return [ - ("Average Accuracy", average_accuracy), - ("Forgetting", forgetting), - ("Forward Transfer", forward_transfer), - ("Backward Transfer", backward_transfer), - ] - else: - raise NotImplementedError(f"Task {task} not implemented.") + return [ + ("Average Accuracy", average_accuracy), + ("Forgetting", forgetting), + ("Forward Transfer", forward_transfer), + ("Backward Transfer", backward_transfer), + ] def cumulative_metrics_summary( @@ -183,8 +177,10 @@ def execute_experiment_job( deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. In this function this parameter is set to True by default. job_name: Name of the experiment job. - strategy: String denoting lightning distributed strategy. - precision: String for which precision to use. + strategy: Name of the distributed training strategy to use. + `More details `__ + precision: Type of bit precision to use. + `More details `__ retain_intermediate_state: Flag to retain models and buffer states after each task update. This is useful when training with large datasets that might cause storage issues. 
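With the task-specific default metrics removed, a config module is expected to provide its own
``metrics_fn``; a minimal sketch that tracks one additional metric (the second metric is only an
example and not part of this API)::

    from typing import Dict

    from torchmetrics import Accuracy, F1Score


    def metrics_fn() -> Dict:
        # "accuracy" feeds the val_accuracy metric used for tuning; extra entries are optional.
        return {"accuracy": Accuracy(), "f1": F1Score()}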
@@ -385,7 +381,7 @@ def _execute_experiment_job_locally( logger.info(f"### Results after update {update_id + 1}: ###") logger.info(df) - cumulative_metrics = create_cumulative_metrics("classification") + cumulative_metrics = create_cumulative_metrics() df = cumulative_metrics_summary(results, cumulative_metrics, num_updates - 1) save_pandas_df_to_csv(df, defaults.metric_summary_file(logs_url)) if not retain_intermediate_state: diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index b90031fd..7a57b447 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import abc -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -26,7 +26,7 @@ class Scenario(abc.ABC): subsequent instantiations. The seed argument is required for these scenarios. Args: - data_module: The source RenateDataModule for the the user data. + data_module: The source RenateDataModule for the user data. num_tasks: The total number of expected tasks for experimentation. chunk_id: The data chunk to load in for the training or validation data. seed: Seed used to fix random number generation. @@ -45,9 +45,12 @@ def __init__( self._verify_chunk_id(chunk_id) self._chunk_id = chunk_id self._seed = seed - self._train_data: Dataset = None - self._val_data: Dataset = None - self._test_data: List[Dataset] = None + self._train_data: Optional[Dataset] = None + self._val_data: Optional[Dataset] = None + self._test_data: Optional[List[Dataset]] = None + self._train_collate_fn: Optional[Callable] = None + self._val_collate_fn: Optional[Callable] = None + self._test_collate_fn: Optional[Callable] = None def prepare_data(self) -> None: """Downloads datasets.""" @@ -56,7 +59,10 @@ def prepare_data(self) -> None: @abc.abstractmethod def setup(self) -> None: """Sets up the scenario.""" - pass + self._data_module.setup() + self._train_collate_fn = self._data_module.train_collate_fn() + self._val_collate_fn = self._data_module.val_collate_fn() + self._test_collate_fn = self._data_module.test_collate_fn() def train_data(self) -> Dataset: """Returns training dataset with respect to current `chunk_id`.""" @@ -70,6 +76,18 @@ def test_data(self) -> List[Dataset]: """Returns the test data with respect to all tasks in `num_tasks`.""" return self._test_data + def train_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for train DataLoader.""" + return self._train_collate_fn + + def val_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for validation DataLoader.""" + return self._val_collate_fn + + def test_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for test DataLoader.""" + return self._test_collate_fn + def _verify_chunk_id(self, chunk_id: int) -> None: """A helper function to verify that the `chunk_id` is valid.""" assert 0 <= chunk_id < self._num_tasks @@ -90,7 +108,7 @@ class BenchmarkScenario(Scenario): """ def setup(self) -> None: - self._data_module.setup() + super().setup() self._train_data = self._data_module.train_data() self._val_data = self._data_module.val_data() self._test_data = self._data_module._test_data @@ -108,7 +126,7 @@ class ClassIncrementalScenario(Scenario): and `y` is the class id. Args: - data_module: The source RenateDataModule for the the user data. + data_module: The source RenateDataModule for the user data. 
chunk_id: The data chunk to load in for the training or validation data. class_groupings: List of lists, describing the division of the classes for respective tasks. """ @@ -117,14 +135,14 @@ def __init__( self, data_module: RenateDataModule, chunk_id: int, - class_groupings: Tuple[Tuple[int]], + class_groupings: Tuple[Tuple[int, ...], ...], ) -> None: super().__init__(data_module, len(class_groupings), chunk_id) self._class_groupings = class_groupings def setup(self) -> None: """Make assignments: val/train/test splits.""" - self._data_module.setup() + super().setup() self._train_data = self._get_task_subset( self._data_module.train_data(), chunk_id=self._chunk_id ) @@ -178,7 +196,7 @@ def __init__( self._transforms = transforms def setup(self) -> None: - self._data_module.setup() + super().setup() self._split_and_assign_train_and_val_data() self._train_data = _TransformedDataset( self._train_data, transform=self._transforms[self._chunk_id] @@ -249,7 +267,7 @@ class IIDScenario(Scenario): def setup(self) -> None: """Make assignments: val/train/test splits.""" - self._data_module.setup() + super().setup() proportions = [1 / self._num_tasks for _ in range(self._num_tasks)] self._train_data = randomly_split_data( self._data_module.train_data(), proportions, self._seed @@ -304,7 +322,7 @@ def _split(self, dataset: Dataset) -> List[Dataset]: def setup(self) -> None: """Make assignments: val/train/test splits.""" - self._data_module.setup() + super().setup() train_data = self._data_module.train_data() self._train_data = self._split(train_data)[self._chunk_id] val_data = self._data_module.val_data() @@ -324,12 +342,12 @@ class FeatureSortingScenario(_SortingScenario): the features. Args: - data_module: The source RenateDataModule for the the user data. + data_module: The source RenateDataModule for the user data. num_tasks: The total number of expected tasks for experimentation. feature_idx: Index of the feature by which to sort. This index refers to the input features `x` of a single data point, i.e., no batch dimension. If the tensor `x` has more than one dimension, this indexes along the 0-dim while additional dimensions will be averaged - out. Hence, for images, `feature_idx` refers to a color channel and we sort by mean + out. Hence, for images, `feature_idx` refers to a color channel, and we sort by mean color channel value. randomness: A value between 0 and 1. For a dataset with ``N`` data points, ``0.5 * N * randomness`` random pairs are swapped. @@ -388,7 +406,7 @@ class WildTimeScenario(Scenario): the test set is all data up to the current time step. Args: - data_module: The source RenateDataModule for the the user data. + data_module: The source RenateDataModule for the user data. num_tasks: The total number of expected tasks for experimentation. chunk_id: The data chunk to load in for the training or validation data. seed: Seed used to fix random number generation. 
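The collate hooks above are only forwarded from the underlying data module; a minimal sketch of
a data module providing one (the class name, toy tensors, and collate logic are illustrative
only)::

    import torch
    from torch.utils.data import TensorDataset

    from renate.data.data_module import RenateDataModule


    def stack_collate(batch):
        # Stack a list of (x, y) pairs into batched tensors.
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)


    class ToyDataModule(RenateDataModule):
        def prepare_data(self) -> None:
            pass  # nothing to download for in-memory toy data

        def setup(self) -> None:
            self._train_data = TensorDataset(torch.randn(80, 10), torch.randint(0, 2, (80,)))
            self._val_data = TensorDataset(torch.randn(10, 10), torch.randint(0, 2, (10,)))
            self._test_data = TensorDataset(torch.randn(10, 10), torch.randint(0, 2, (10,)))
            # train_collate_fn() etc. simply return these attributes, and scenarios
            # copy them over in their own setup().
            self._train_collate_fn = stack_collate
            self._val_collate_fn = stack_collate
            self._test_collate_fn = stack_collate

An instance (created with the usual ``data_path`` constructor argument) then hands
``stack_collate`` to the model updater via ``train_collate_fn()``.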
@@ -408,7 +426,7 @@ def __init__( def setup(self) -> None: """Sets up the scenario.""" self._data_module.time_step = self._chunk_id - self._data_module.setup() + super().setup() self._train_data = self._data_module.train_data() self._val_data = self._data_module.val_data() self._test_data = [] diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index c1281beb..bc5b9f99 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -177,6 +177,8 @@ def run(self): model_updater.update( train_dataset=data_module.train_data(), val_dataset=data_module.val_data(), + train_dataset_collate_fn=data_module.train_collate_fn(), + val_dataset_collate_fn=data_module.val_collate_fn(), task_id=args.task_id, ) diff --git a/src/renate/data/data_module.py b/src/renate/data/data_module.py index da792b3e..436c790c 100644 --- a/src/renate/data/data_module.py +++ b/src/renate/data/data_module.py @@ -3,7 +3,7 @@ import abc import os from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import pandas as pd import torch @@ -55,6 +55,9 @@ def __init__( self._train_data: Optional[Dataset] = None self._val_data: Optional[Dataset] = None self._test_data: Optional[Dataset] = None + self._train_collate_fn: Optional[Callable] = None + self._val_collate_fn: Optional[Callable] = None + self._test_collate_fn: Optional[Callable] = None assert 0.0 <= val_size <= 1.0 self._val_size = val_size self._seed = seed @@ -83,6 +86,18 @@ def test_data(self) -> Dataset: """Returns test dataset.""" return self._test_data + def train_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for train DataLoader.""" + return self._train_collate_fn + + def val_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for validation DataLoader.""" + return self._val_collate_fn + + def test_collate_fn(self) -> Optional[Callable]: + """Returns collate_fn for test DataLoader.""" + return self._test_collate_fn + def _verify_file(self, file_name: str) -> bool: """A helper function that verifies that the required dataset files are downloaded and correct. diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 5c6208f6..d8a16948 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -43,7 +43,6 @@ FRAMEWORK_VERSION = "1.12.0" TASK_ID = "default_task" -SUPPORTED_TASKS_TYPE = Literal["classification", "regression"] WORKING_DIRECTORY = "renate_working_dir" LOGGER = TensorBoardLogger LOGGER_KWARGS = { diff --git a/src/renate/evaluation/evaluator.py b/src/renate/evaluation/evaluator.py index 743040a1..c085def6 100644 --- a/src/renate/evaluation/evaluator.py +++ b/src/renate/evaluation/evaluator.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import abc -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch import torchmetrics @@ -11,7 +11,6 @@ from renate import defaults from renate.data.datasets import _TransformedDataset -from renate.evaluation.metrics.utils import create_metrics from renate.models import RenateModule from renate.utils.distributed_strategies import create_strategy from renate.utils.misc import int_or_str @@ -26,7 +25,6 @@ class Evaluator(LightningModule, abc.ABC): Args: model: A `RenateModule` to be evaluated. - task: The machine learning problem considered. batch_size: The batch size to be used when creating the test data loader. 
transform: The transformation applied for evaluation. target_transform: The target transformation applied for evaluation. @@ -36,7 +34,6 @@ class Evaluator(LightningModule, abc.ABC): def __init__( self, model: RenateModule, - task: defaults.SUPPORTED_TASKS_TYPE, batch_size: int, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, @@ -48,10 +45,13 @@ def __init__( self._batch_size = batch_size self._transform = transform self._target_transform = target_transform - self._metric_collection = create_metrics(task=task, additional_metrics=logged_metrics) + self._metric_collection = torchmetrics.MetricCollection(logged_metrics) def on_model_test_start( - self, test_dataset: Dataset, task_id: Optional[str] = None + self, + test_dataset: Dataset, + test_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> DataLoader: """Called before a model test starts.""" test_dataset = _TransformedDataset( @@ -60,7 +60,13 @@ def on_model_test_start( target_transform=self._target_transform, ) self._task_id = task_id - return DataLoader(test_dataset, batch_size=self._batch_size, shuffle=False, pin_memory=True) + return DataLoader( + test_dataset, + batch_size=self._batch_size, + shuffle=False, + pin_memory=True, + collate_fn=test_collate_fn, + ) def test_step(self, batch: List[torch.Tensor], batch_idx: int) -> None: """PyTorch Lightning function to perform the test step.""" @@ -91,9 +97,6 @@ class ClassificationEvaluator(Evaluator): dataset. """ - def __init__(self, **kwargs: Any): - super().__init__(task="classification", **kwargs) - def forward(self, x, task_id: Optional[str] = None) -> torch.Tensor: """Forward pass of the model. @@ -108,6 +111,7 @@ def forward(self, x, task_id: Optional[str] = None) -> torch.Tensor: def evaluate( model: RenateModule, test_dataset: Union[List[Dataset], Dataset], + test_collate_fn: Optional[Callable] = None, task_id: Union[List[str], str] = defaults.TASK_ID, batch_size: int = defaults.BATCH_SIZE, transform: Optional[Callable] = None, @@ -130,6 +134,7 @@ def evaluate( Args: model: A `RenateModule` to be evaluated. test_dataset: The test dataset(s) to be evaluated. + test_collate_fn: collate_fn used in the DataLoader. task_id: The task id(s) of the test dataset(s). batch_size: The batch size to be used when creating the test data loader. transform: The transformation applied for evaluation. @@ -140,6 +145,10 @@ def evaluate( devices: Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be "auto" and fetch the `auto_device_count` from the `accelerator`. + strategy: Name of the distributed training strategy to use. + `More details `__ + precision: Type of bit precision to use. + `More details `__ """ if isinstance(test_dataset, Dataset): test_dataset = [test_dataset] @@ -169,7 +178,7 @@ def evaluate( results = {} for i in range(len(test_dataset)): - test_loader = evaluator.on_model_test_start(test_dataset[i], task_id[i]) + test_loader = evaluator.on_model_test_start(test_dataset[i], test_collate_fn, task_id[i]) trainer.test( evaluator, test_loader, diff --git a/src/renate/evaluation/metrics/utils.py b/src/renate/evaluation/metrics/utils.py deleted file mode 100644 index 09780807..00000000 --- a/src/renate/evaluation/metrics/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
-# SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional - -import torchmetrics - -import renate.defaults as defaults - - -def create_metrics( - task: defaults.SUPPORTED_TASKS_TYPE, - additional_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, -) -> torchmetrics.MetricCollection: - """Creates task-specific metrics including all additional metrics. - - Args: - task: Whether classification or regression, for now. - additional_metrics: Dictionary of additionally metrics to be added to the returned - `MetricCollection`. - """ - if task == "classification": - metric_collection = { - "accuracy": torchmetrics.Accuracy(), - } - elif task == "regression": - metric_collection = {"mean_squared_error": torchmetrics.MeanSquaredError()} - else: - raise NotImplementedError(f"Task {task} not implemented.") - if additional_metrics: - assert set(metric_collection).isdisjoint(set(additional_metrics)), ( - "Use a different name for your custom metrics. Following names are reserved for the " - f"default metrics: {set(metric_collection)}." - ) - metric_collection.update(additional_metrics) - return torchmetrics.MetricCollection(metric_collection) diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 3e33917d..f2d876b9 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -127,7 +127,7 @@ def run_training_job( max_num_trials_finished: Stopping criterion: trials finished. max_cost: (SageMaker backend only) Stopping criterion: SageMaker cost. n_workers: Number of workers running in parallel. - scheduler: Default is random search, you can change it by providing a either a string + scheduler: Default is random search, you can change it by providing either a string (`random`, `bo`, `asha` or `rush`) or scheduler class and its corresponding `scheduler_kwargs` if required. For latter option, `see details at `_ . @@ -136,7 +136,9 @@ def run_training_job( accelerator: Type of accelerator to use. devices: Number of devices to use per worker (set in n_workers). strategy: Name of the distributed training strategy to use. + `More details `__ precision: Type of bit precision to use. + `More details `__ deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. job_name: Prefix for the name of the SageMaker training job. 
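For instance, a local update of an ``ER`` learner with ASHA-based tuning of the learning rate
might be launched roughly as follows; the config file, search space, and output location are
placeholders, and ``loguniform`` is Syne Tune's search-space primitive::

    from syne_tune.config_space import loguniform

    from renate.training.training import run_training_job

    run_training_job(
        updater="ER",
        config_file="renate_config.py",
        config_space={"learning_rate": loguniform(1e-4, 0.1), "val_size": 0.2},
        mode="max",
        metric="val_accuracy",
        scheduler="asha",
        backend="local",
        output_state_url="renate_output/",
        max_epochs=5,
    )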
""" diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index bd92bef0..92eda453 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -165,10 +165,14 @@ def update( self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, task_id: Optional[str] = None, ) -> RenateModule: val_dataset_exists = val_dataset is not None - benchmark = self._load_benchmark_if_exists(train_dataset, val_dataset) + benchmark = self._load_benchmark_if_exists( + train_dataset, val_dataset, train_dataset_collate_fn, val_dataset_collate_fn + ) train_exp = benchmark.train_stream[0] self._learner.train(train_exp, eval_streams=[benchmark.test_stream]) results = self._learner.eval(benchmark.test_stream) @@ -193,9 +197,13 @@ def update( return self._model def _load_benchmark_if_exists( - self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None + self, + train_dataset: Dataset, + val_dataset: Optional[Dataset] = None, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, ) -> AvalancheBenchmarkWrapper: - train_dataset = to_avalanche_dataset(train_dataset) + train_dataset = to_avalanche_dataset(train_dataset, train_dataset_collate_fn) avalanche_state = None if self._input_state_folder is not None: @@ -209,9 +217,11 @@ def _load_benchmark_if_exists( self._dummy_learner.load(self._input_state_folder) if val_dataset is not None: self._dummy_learner._val_memory_buffer.update(val_dataset) - val_memory_dataset = to_avalanche_dataset(self._dummy_learner._val_memory_buffer) + val_memory_dataset = to_avalanche_dataset( + self._dummy_learner._val_memory_buffer, val_dataset_collate_fn + ) else: - val_memory_dataset = to_avalanche_dataset(train_dataset) + val_memory_dataset = to_avalanche_dataset(train_dataset, val_dataset_collate_fn) benchmark = AvalancheBenchmarkWrapper( train_dataset=train_dataset, diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index 3005a18a..e32a06c1 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -79,10 +79,21 @@ def _create_metrics_collections( self._loss_collections["train_losses"].update({name: torchmetrics.MeanMetric()}) def on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> None: """Called before a model update starts.""" - super().on_model_update_start(train_dataset, val_dataset, task_id) + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._set_memory_loader() def train_dataloader(self) -> DataLoader: diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index e978d023..71379b3d 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -63,10 +63,21 @@ def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self._memory_buffer.load_state_dict(checkpoint["memory_buffer"]) def 
on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> None: """Called before a model update starts.""" - super().on_model_update_start(train_dataset, val_dataset, task_id) + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._memory_buffer.update(train_dataset) reinitialize_model_parameters(self._model) diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index 179b9667..c21ca9fa 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -56,10 +56,21 @@ def load(self, input_state_dir: str) -> None: self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer")) def on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> None: """Called before a model update starts.""" - super().on_model_update_start(train_dataset, val_dataset, task_id) + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._memory_buffer.update(train_dataset) reinitialize_model_parameters(self._model) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 4836015d..49b03bda 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -59,10 +59,21 @@ def _create_metrics_collections( self._loss_collections["train_losses"]["memory_loss"] = torchmetrics.MeanMetric() def on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> None: """Called before a model update starts.""" - super().on_model_update_start(train_dataset, val_dataset, task_id) + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._num_points_current_task = len(train_dataset) def train_dataloader(self) -> DataLoader: diff --git a/src/renate/updaters/experimental/repeated_distill.py b/src/renate/updaters/experimental/repeated_distill.py index edfba727..6a127f3c 100644 --- a/src/renate/updaters/experimental/repeated_distill.py +++ b/src/renate/updaters/experimental/repeated_distill.py @@ -158,6 +158,8 @@ def update( self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, task_id: Optional[str] = None, ) -> RenateModule: """Updates the model using the data passed as input. 
@@ -165,6 +167,10 @@ def update( Args: train_dataset: The training data. val_dataset: The validation data. + train_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the training data. + val_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the validation data. task_id: The task id. """ # First, train a copy of the model on the new data from scratch as an expert model. We use @@ -178,7 +184,13 @@ def update( train_target_transform=self._train_target_transform, **{key: value for key, value in self._learner_kwargs.items() if key != "memory_size"}, ) - expert_learner.on_model_update_start(train_dataset, val_dataset, task_id) + expert_learner.on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._fit_learner(expert_learner) # Extract logits from the expert model and register them with the consolidation learner. @@ -194,7 +206,13 @@ def update( del expert_learner # Run consolidation. - self._learner.on_model_update_start(train_dataset, val_dataset, task_id) + self._learner.on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._fit_learner(self._learner) return self._model @@ -211,10 +229,21 @@ def update_expert_logits(self, new_expert_logits: torch.Tensor) -> None: self._expert_logits = new_expert_logits def on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[int] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[int] = None, ) -> None: """Called before a model update starts.""" - super().on_model_update_start(train_dataset, val_dataset, task_id) + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._memory_buffer.update(train_dataset, metadata={"logits": self._expert_logits}) reinitialize_model_parameters(self._model) self._expert_logits = None diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index 7c943f3e..b84b27ea 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -17,26 +17,22 @@ from renate import defaults from renate.data.datasets import _TransformedDataset -from renate.evaluation.metrics.utils import create_metrics from renate.memory import DataBuffer, InfiniteBuffer, ReservoirBuffer from renate.models import RenateModule from renate.types import NestedTensors from renate.utils.pytorch import get_generator -class Learner(LightningModule, abc.ABC): - """Base class for Learners, which encapsulate the core CL methodologies. +class RenateLightningModule(LightningModule, abc.ABC): + """Base class for LightningModules, which implement metric logging and basic training logic. - The `Learner` is a `LightningModule`, but provides additional hook functions + The `RenateLightningModule` is a `LightningModule`, but provides additional hook functions called by `ModelUpdater`. 
These hooks are: - - `Learner.on_model_update_start`, which is called in the beginning of a + - `on_model_update_start`, which is called in the beginning of a model update. We expect this to return train and (optionally) validation data loader(s). - - `Learner.on_model_update_end`, which is called in the end of a model update. - - This base class implements a basic training loop without any mechanism to - counteract forgetting. + - `on_model_update_end`, which is called in the end of a model update. Args: model: The model to be trained. @@ -46,10 +42,6 @@ class Learner(LightningModule, abc.ABC): learning_rate_scheduler_interval: When to update the learning rate scheduler. Options: `epoch` and `step`. batch_size: Training batch size. - train_transform: The transformation applied during training. - train_target_transform: The target transformation applied during testing. - test_transform: The transformation at test time. - test_target_transform: The target transformation at test time. logged_metrics: Metrics logged additional to the default ones. seed: See :func:`renate.models.utils.get_generator`. """ @@ -62,10 +54,6 @@ def __init__( learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, - train_transform: Optional[Callable] = None, - train_target_transform: Optional[Callable] = None, - test_transform: Optional[Callable] = None, - test_target_transform: Optional[Callable] = None, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, seed: int = defaults.SEED, ) -> None: @@ -76,41 +64,35 @@ def __init__( self._learning_rate_scheduler = learning_rate_scheduler self._learning_rate_scheduler_interval = learning_rate_scheduler_interval self._batch_size = batch_size - self._train_transform = train_transform - self._train_target_transform = train_target_transform - self._test_transform = test_transform - self._test_target_transform = test_target_transform self._seed = seed self._task_id: str = defaults.TASK_ID + self._train_dataset: Optional[Dataset] = None + self._val_dataset: Optional[Dataset] = None + self.val_enabled = False + self._train_collate_fn: Optional[Callable] = None + self._val_collate_fn: Optional[Callable] = None - self._val_memory_buffer: DataBuffer = InfiniteBuffer() self._create_metrics_collections(logged_metrics) self._rng = get_generator(self._seed) - self.save_hyperparameters( - ignore=[ - "model", - "loss_fn", - "optimizer", - "learning_rate_scheduler", - "components", - "train_transform", - "test_transform", - "buffer_transform", - "train_transform", - "train_target_transform", - "test_transform", - "test_target_transform", - "buffer_transform", - "buffer_target_transform", - "logged_metrics", - ] - ) + self.save_hyperparameters(ignore=self._ignored_hyperparameters()) + + def _ignored_hyperparameters(self): + """Hyperparameters to be ignored in the ``save_hyperparameters`` call.""" + return [ + "model", + "loss_fn", + "optimizer", + "learning_rate_scheduler", + "logged_metrics", + ] def _create_metrics_collections( self, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None ) -> None: """Creates all logged metrics.""" - metrics = create_metrics(task="classification", additional_metrics=logged_metrics) + if logged_metrics is None: + logged_metrics = {} + metrics = torchmetrics.MetricCollection(logged_metrics) train_metrics = metrics.clone(prefix="train_") 
val_metrics = metrics.clone(prefix="val_") @@ -135,24 +117,6 @@ def _create_metrics_collections( } ) - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - learner_state_dict = { - "learner_class_name": self.__class__.__name__, - "val_memory_buffer": self._val_memory_buffer.state_dict(), - } - checkpoint.update(learner_state_dict) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]): - self._val_memory_buffer.load_state_dict(checkpoint["val_memory_buffer"]) - - def save(self, output_state_dir: str) -> None: - val_buffer_dir = os.path.join(output_state_dir, "val_memory_buffer") - os.makedirs(val_buffer_dir, exist_ok=True) - self._val_memory_buffer.save(val_buffer_dir) - - def load(self, input_state_dir: str) -> None: - self._val_memory_buffer.load(os.path.join(input_state_dir, "val_memory_buffer")) - def is_logged_metric(self, metric_name: str) -> bool: """Returns `True` if there is a metric with name `metric_name`.""" if metric_name is None: @@ -171,48 +135,42 @@ def is_logged_metric(self, metric_name: str) -> bool: return metric_name in logged_metrics def on_model_update_start( - self, train_dataset: Dataset, val_dataset: Dataset, task_id: Optional[str] = None + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, ) -> None: self._train_dataset = train_dataset self._val_dataset = val_dataset self.val_enabled = val_dataset is not None and len(val_dataset) + self._train_collate_fn = train_dataset_collate_fn + self._val_collate_fn = val_dataset_collate_fn self._task_id = task_id self._model.add_task_params(task_id=self._task_id) def train_dataloader(self) -> DataLoader: """Returns the dataloader for training the model.""" - train_dataset = _TransformedDataset( - self._train_dataset, - transform=self._train_transform, - target_transform=self._train_target_transform, - ) return DataLoader( - train_dataset, + self._train_dataset, batch_size=self._batch_size, shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) - def val_dataloader(self) -> DataLoader: + def val_dataloader(self) -> Optional[DataLoader]: if self._val_dataset is not None: - val_dataset = _TransformedDataset( - self._val_dataset, - transform=self._test_transform, - target_transform=self._test_target_transform, - ) - self._val_memory_buffer.update(val_dataset) - - if len(self._val_memory_buffer): return DataLoader( - self._val_memory_buffer, + self._val_dataset, batch_size=self._batch_size, shuffle=False, generator=self._rng, pin_memory=True, + collate_fn=self._val_collate_fn, ) - else: - return None def on_model_update_end(self) -> None: """Called right before a model update terminates.""" @@ -224,11 +182,15 @@ def forward(self, inputs: NestedTensors, task_id: Optional[str] = None) -> torch task_id = self._task_id return self._model(inputs, task_id=task_id) + def training_step_unpack_batch(self, batch: Tuple[Any, Any]) -> Tuple[Any, Any]: + inputs, targets = batch + return inputs, targets + def training_step( self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int ) -> STEP_OUTPUT: """PyTorch Lightning function to return the training loss.""" - inputs, targets = batch + inputs, targets = self.training_step_unpack_batch(batch) outputs = self(inputs) intermediate_representation = self._model.get_intermediate_representation() self._model.reset_intermediate_representation_cache() @@ -253,9 +215,13 @@ def 
training_epoch_end(self, outputs: List[Union[Tensor, Dict[str, Any]]]) -> No if not self.val_enabled: self._log_metrics() + def validation_step_unpack_batch(self, batch: Tuple[Tuple[Any, Any], Any]) -> Tuple[Any, Any]: + (inputs, targets), _ = batch + return inputs, targets + def validation_step(self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int) -> None: """PyTorch Lightning function to estimate validation metrics.""" - (inputs, targets), _ = batch + inputs, targets = self.validation_step_unpack_batch(batch) outputs = self(inputs) loss = self._loss_fn(outputs, targets) self._update_metrics(outputs, targets, "val") @@ -317,6 +283,156 @@ def _log_metrics( loss.reset() +class Learner(RenateLightningModule, abc.ABC): + """Base class for Learners, which encapsulate the core CL methodologies. + + The `Learner` is a `LightningModule`, but provides additional hook functions + called by `ModelUpdater`. These hooks are: + + - `Learner.on_model_update_start`, which is called in the beginning of a + model update. We expect this to return train and (optionally) validation + data loader(s). + - `Learner.on_model_update_end`, which is called in the end of a model update. + + This base class implements a basic training loop without any mechanism to + counteract forgetting. + + Args: + model: The model to be trained. + optimizer: Partial optimizer used to create an optimizer by passing the model parameters. + learning_rate_scheduler: Partial object of learning rate scheduler that will be created by + passing the optimizer. + learning_rate_scheduler_interval: When to update the learning rate scheduler. + Options: `epoch` and `step`. + batch_size: Training batch size. + train_transform: The transformation applied during training. + train_target_transform: The target transformation applied during testing. + test_transform: The transformation at test time. + test_target_transform: The target transformation at test time. + logged_metrics: Metrics logged additional to the default ones. + seed: See :func:`renate.models.utils.get_generator`. 
+ """ + + def __init__( + self, + model: RenateModule, + loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], + learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 + batch_size: int = defaults.BATCH_SIZE, + train_transform: Optional[Callable] = None, + train_target_transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + test_target_transform: Optional[Callable] = None, + logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, + seed: int = defaults.SEED, + ) -> None: + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, + batch_size=batch_size, + logged_metrics=logged_metrics, + seed=seed, + ) + self._train_transform = train_transform + self._train_target_transform = train_target_transform + self._test_transform = test_transform + self._test_target_transform = test_target_transform + self._val_memory_buffer: DataBuffer = InfiniteBuffer() + + def _ignored_hyperparameters(self): + """Hyperparameters to be ignored in the ``save_hyperparameters`` call.""" + return super()._ignored_hyperparameters() + [ + "components", + "train_transform", + "train_target_transform", + "test_transform", + "test_target_transform", + "buffer_transform", + "buffer_target_transform", + ] + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + learner_state_dict = { + "learner_class_name": self.__class__.__name__, + "val_memory_buffer": self._val_memory_buffer.state_dict(), + } + checkpoint.update(learner_state_dict) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]): + self._val_memory_buffer.load_state_dict(checkpoint["val_memory_buffer"]) + + def save(self, output_state_dir: str) -> None: + val_buffer_dir = os.path.join(output_state_dir, "val_memory_buffer") + os.makedirs(val_buffer_dir, exist_ok=True) + self._val_memory_buffer.save(val_buffer_dir) + + def load(self, input_state_dir: str) -> None: + self._val_memory_buffer.load(os.path.join(input_state_dir, "val_memory_buffer")) + + def on_model_update_start( + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, + ) -> None: + super().on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) + self._model.add_task_params(task_id=self._task_id) + + def train_dataloader(self) -> DataLoader: + """Returns the dataloader for training the model.""" + train_dataset = _TransformedDataset( + self._train_dataset, + transform=self._train_transform, + target_transform=self._train_target_transform, + ) + return DataLoader( + train_dataset, + batch_size=self._batch_size, + shuffle=True, + generator=self._rng, + pin_memory=True, + collate_fn=self._train_collate_fn, + ) + + def val_dataloader(self) -> Optional[DataLoader]: + if self._val_dataset is not None: + val_dataset = _TransformedDataset( + self._val_dataset, + transform=self._test_transform, + target_transform=self._test_target_transform, + ) + self._val_memory_buffer.update(val_dataset) + + if len(self._val_memory_buffer): + return 
DataLoader( + self._val_memory_buffer, + batch_size=self._batch_size, + shuffle=False, + generator=self._rng, + pin_memory=True, + collate_fn=self._val_collate_fn, + ) + + def validation_step_unpack_batch( + self, batch: Tuple[NestedTensors, torch.Tensor] + ) -> Tuple[NestedTensors, Any]: + (inputs, targets), _ = batch + return inputs, targets + + class ReplayLearner(Learner, abc.ABC): """Base class for Learners which use a buffer to store data and reuse it in future updates. diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index b16df0f0..451440b0 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -340,6 +340,8 @@ def update( self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, task_id: Optional[str] = None, ) -> None: """Updates the model using the data passed as input. @@ -347,6 +349,10 @@ def update( Args: train_dataset: The training data. val_dataset: The validation data. + train_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the training data. + val_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the validation data. task_id: The task id. """ @@ -424,6 +430,8 @@ def update( self, train_dataset: Dataset, val_dataset: Optional[Dataset] = None, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, task_id: Optional[str] = None, ) -> RenateModule: """Updates the model using the data passed as input. @@ -431,8 +439,18 @@ def update( Args: train_dataset: The training data. val_dataset: The validation data. + train_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the training data. + val_dataset_collate_fn: collate_fn used to merge a list of samples to form a + mini-batch of Tensors for the validation data. task_id: The task id. """ - self._learner.on_model_update_start(train_dataset, val_dataset, task_id) + self._learner.on_model_update_start( + train_dataset=train_dataset, + val_dataset=val_dataset, + train_dataset_collate_fn=train_dataset_collate_fn, + val_dataset_collate_fn=val_dataset_collate_fn, + task_id=task_id, + ) self._fit_learner(self._learner) return self._model diff --git a/src/renate/utils/avalanche.py b/src/renate/utils/avalanche.py index faa04439..42fb5984 100644 --- a/src/renate/utils/avalanche.py +++ b/src/renate/utils/avalanche.py @@ -1,7 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 from collections import Counter -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import torch from avalanche.benchmarks import dataset_benchmark @@ -17,10 +17,14 @@ class AvalancheDataset(Dataset): """A Dataset consumable by Avalanche updaters.""" - def __init__(self, inputs: NestedTensors, targets: List[int]): + def __init__( + self, inputs: NestedTensors, targets: List[int], collate_fn: Optional[Callable] = None + ): self._inputs = inputs self._targets = targets self.targets = torch.tensor(targets, dtype=torch.long) + if collate_fn is not None: + self.collate_fn = collate_fn def __len__(self) -> int: return len(self._targets) @@ -29,7 +33,9 @@ def __getitem__(self, idx) -> Tuple[Tensor, Tensor]: return self._inputs[idx], self._targets[idx] -def to_avalanche_dataset(dataset: Union[Dataset, DataBuffer]) -> AvalancheDataset: +def to_avalanche_dataset( + dataset: Union[Dataset, DataBuffer], collate_fn: Optional[Callable] = None +) -> AvalancheDataset: """Converts a DataBuffer or Dataset into an Avalanche-compatible Dataset.""" x_data, y_data = [], [] for i in range(len(dataset)): @@ -41,7 +47,7 @@ def to_avalanche_dataset(dataset: Union[Dataset, DataBuffer]) -> AvalancheDatase if not isinstance(y, int): y = y.item() y_data.append(y) - return AvalancheDataset(x_data, y_data) + return AvalancheDataset(x_data, y_data, collate_fn) class AvalancheBenchmarkWrapper: diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 5486933f..0b0c8cf3 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -51,6 +51,10 @@ def evaluate_and_record_results( devices: Devices used by PyTorch Lightning to train the model. If the devices flag is not defined, it will assume devices to be "auto" and fetch the `auto_device_count` from the `accelerator`. + strategy: Name of the distributed training strategy to use. + `More details `__ + precision: Type of bit precision to use. + `More details `__ """ data_module.setup() diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 0bb06f2f..c183b9eb 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -11,6 +11,7 @@ get_data_module, get_scenario, loss_fn, + metrics_fn, model_fn, models, train_transform, @@ -312,3 +313,7 @@ def test_prediction_strategy_is_correctly_set(model_name, updater): def test_loss_fn_returns_correct_reduction_type(): assert loss_fn("ER").reduction == "none" assert loss_fn("Avalanche-ER").reduction == "mean" + + +def test_metrics_fn_contains_accuracy(): + assert "accuracy" in metrics_fn() diff --git a/test/renate/renate_config_files/config.py b/test/renate/renate_config_files/config.py index 14b4a435..0ea96599 100644 --- a/test/renate/renate_config_files/config.py +++ b/test/renate/renate_config_files/config.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch +from torchmetrics import Accuracy from dummy_datasets import DummyTorchVisionDataModule from renate.benchmark.models.mlp import MultiLayerPerceptron @@ -34,3 +35,7 @@ def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: if updater.startswith("Avalanche-"): return torch.nn.CrossEntropyLoss() return torch.nn.CrossEntropyLoss(reduction="none") + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/test/renate/renate_config_files/config_custom_optimizer.py b/test/renate/renate_config_files/config_custom_optimizer.py index cdea6820..eac8a25f 100644 --- a/test/renate/renate_config_files/config_custom_optimizer.py +++ b/test/renate/renate_config_files/config_custom_optimizer.py @@ -1,12 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 from functools import partial -from typing import Callable, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple import torch from torch.nn import Parameter from torch.optim import Optimizer, SGD from torch.optim.lr_scheduler import StepLR, _LRScheduler +from torchmetrics import Accuracy from dummy_datasets import DummyTorchVisionDataModule from renate.benchmark.models.mlp import MultiLayerPerceptron @@ -41,3 +42,7 @@ def optimizer_fn() -> Callable[[List[Parameter]], Optimizer]: def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]: return partial(StepLR, step_size=10, gamma=0.1), "epoch" + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/test/renate/renate_config_files/config_scenario.py b/test/renate/renate_config_files/config_scenario.py index cad1efd1..77ec44cc 100644 --- a/test/renate/renate_config_files/config_scenario.py +++ b/test/renate/renate_config_files/config_scenario.py @@ -1,13 +1,13 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch +from torchmetrics import Accuracy from dummy_datasets import DummyTorchVisionDataModule from renate.benchmark.models.mlp import MultiLayerPerceptron -from renate.benchmark.scenarios import ClassIncrementalScenario -from renate.data.data_module import RenateDataModule +from renate.benchmark.scenarios import ClassIncrementalScenario, Scenario from renate.models import RenateModule @@ -23,8 +23,8 @@ def data_module_fn( chunk_id: Optional[int] = None, val_size: float = 0.0, seed: int = 0, - class_groupings: Tuple[Tuple[int]] = ((0, 1), (2, 3, 4)), -) -> RenateDataModule: + class_groupings: Tuple[Tuple[int, ...], ...] 
= ((0, 1), (2, 3, 4)), +) -> Scenario: data_module = DummyTorchVisionDataModule(transform=None, val_size=val_size, seed=seed) return ClassIncrementalScenario( data_module=data_module, @@ -37,3 +37,7 @@ def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: if updater.startswith("Avalanche-"): return torch.nn.CrossEntropyLoss() return torch.nn.CrossEntropyLoss(reduction="none") + + +def metrics_fn() -> Dict: + return {"accuracy": Accuracy()} diff --git a/test/renate/training/test_run_training.py b/test/renate/training/test_run_training.py index 082bd804..c4ccad10 100644 --- a/test/renate/training/test_run_training.py +++ b/test/renate/training/test_run_training.py @@ -78,7 +78,7 @@ def execute_job(): mode="max", config_space=config_space, metric="val_accuracy", - max_time=30, + max_time=35, scheduler=scheduler, ) diff --git a/test/renate/updaters/test_model_updater.py b/test/renate/updaters/test_model_updater.py index 7e541f91..9f978b57 100644 --- a/test/renate/updaters/test_model_updater.py +++ b/test/renate/updaters/test_model_updater.py @@ -68,7 +68,7 @@ def test_deterministic_updater(): @pytest.mark.parametrize("early_stopping_enabled", [True, False]) @pytest.mark.parametrize("use_val", [True, False]) -@pytest.mark.parametrize("metric_monitored", [None, "val_accuracy"]) +@pytest.mark.parametrize("metric_monitored", [None, "val_loss"]) @pytest.mark.parametrize("updater_type", ["DMC", "SimpleUpdater"]) def test_model_updater_with_early_stopping( use_val, early_stopping_enabled, metric_monitored, updater_type @@ -89,7 +89,7 @@ def test_model_updater_with_early_stopping( model_updater = RepeatedDistillationModelUpdater( model=model, loss_fn=pytest.helpers.get_loss_fn(), - optimizer=pytest.helpers.get_partial_optimizer(), + optimizer=pytest.helpers.get_partial_optimizer(lr=0.3), memory_size=50, max_epochs=max_epochs, early_stopping_enabled=early_stopping_enabled, diff --git a/test/renate/utils/test_metrics_utils.py b/test/renate/utils/test_metrics_utils.py deleted file mode 100644 index 4205fe47..00000000 --- a/test/renate/utils/test_metrics_utils.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# SPDX-License-Identifier: Apache-2.0 -import pytest -import torchmetrics - -from renate import defaults -from renate.evaluation.metrics.utils import create_metrics - - -@pytest.mark.parametrize( - "task", - defaults.SUPPORTED_TASKS_TYPE.__args__ + ("unsupported_task",), - ids=defaults.SUPPORTED_TASKS_TYPE.__args__ + ("unsupported_task",), -) -def test_create_metrics_without_additional_metrics(task): - """Tests all allowed cases and all edge cases when no additional metrics is provided. - - Edge cases: not supported task. - """ - if task == "unsupported_task": - with pytest.raises(NotImplementedError): - create_metrics(task=task) - else: - metric_collection = create_metrics(task=task) - assert metric_collection.prefix is None - - -@pytest.mark.parametrize( - "additional_metrics, is_duplicate", - [ - ( - { - "auc": torchmetrics.AUC(), - "calibration_error": torchmetrics.CalibrationError(), - }, - False, - ), - ( - { - "accuracy": torchmetrics.Accuracy(), - }, - True, - ), - ], - ids=["valid_metric_names", "duplicate_metric_names"], -) -def test_given_additional_metrics_create_metrics_adds_them_to_metric_collection( - additional_metrics, is_duplicate -): - """Passes if duplicate metric names raise an exception and additional metrics are added to the - metric collection. - - Case 1: Valid case. 
- Case 2: Metric name already used as part of the default metric collection. - """ - if is_duplicate: - with pytest.raises(AssertionError): - create_metrics(task="classification", additional_metrics=additional_metrics) - else: - metric_collection = create_metrics( - task="classification", additional_metrics=additional_metrics - ) - assert metric_collection.prefix is None - assert set(additional_metrics).issubset(set(metric_collection)) From 8790f7fabc51ebee645ce95a056871241ccfea50 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 26 Jun 2023 16:12:15 +0200 Subject: [PATCH 42/89] Fix Small Bug in Benchmarking Script and Add LR Scheduler to Experiment Config (#305) --- src/renate/benchmark/experiment_config.py | 23 +++ src/renate/benchmark/experimentation.py | 23 ++- src/renate/defaults.py | 2 +- src/renate/utils/file.py | 135 +++++++++++++----- test/renate/benchmark/test_experimentation.py | 23 ++- .../benchmark/test_experimentation_config.py | 24 ++++ test/renate/utils/test_file.py | 57 +++++++- 7 files changed, 220 insertions(+), 67 deletions(-) diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index d5009af4..8854c39e 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -1,9 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from functools import partial from typing import Callable, Dict, List, Optional, Tuple, Union import torch import wild_time_data +from torch.optim import Optimizer +from torch.optim.lr_scheduler import StepLR, _LRScheduler from torchmetrics import Accuracy from torchvision.transforms import transforms from transformers import AutoTokenizer @@ -338,5 +341,25 @@ def test_transform(dataset_name: str) -> Optional[Callable]: raise ValueError(f"Unknown dataset `{dataset_name}`.") +def lr_scheduler_fn( + learning_rate_scheduler: Optional[str] = None, + learning_rate_scheduler_step_size: int = 30, + learning_rate_scheduler_gamma: float = 0.1, + learning_rate_scheduler_interval: str = "epoch", +) -> Tuple[Optional[Callable[[Optimizer], _LRScheduler]], str]: + if learning_rate_scheduler == "StepLR": + return ( + partial( + StepLR, + step_size=learning_rate_scheduler_step_size, + gamma=learning_rate_scheduler_gamma, + ), + learning_rate_scheduler_interval, + ) + elif learning_rate_scheduler is None: + return None, learning_rate_scheduler_interval + raise ValueError(f"Unknown scheduler `{learning_rate_scheduler}`.") + + def metrics_fn() -> Dict: return {"accuracy": Accuracy()} diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 474aef08..572dff31 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -144,7 +144,7 @@ def execute_experiment_job( job_name: str = defaults.JOB_NAME, strategy: str = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, - retain_intermediate_state: bool = defaults.RETAIN_INTERMEDIATE_STATE, + save_state: bool = defaults.SAVE_BENCHMARK_STATE, ) -> None: """Executes the experiment job. @@ -181,9 +181,8 @@ def execute_experiment_job( `More details `__ precision: Type of bit precision to use. `More details `__ - retain_intermediate_state: Flag to retain models and buffer states after each - task update. This is useful when training with large datasets that might cause storage - issues. + save_state: Flag to retain models and buffer states of each update step. 
Disable to save + storage. """ assert ( mode in defaults.SUPPORTED_TUNING_MODE @@ -211,7 +210,7 @@ def execute_experiment_job( seed=seed, strategy=strategy, precision=precision, - retain_intermediate_state=retain_intermediate_state, + save_state=save_state, ) _execute_experiment_job_remotely( job_name=job_name, @@ -239,7 +238,7 @@ def execute_experiment_job( instance_max_time=instance_max_time, strategy=strategy, precision=precision, - retain_intermediate_state=retain_intermediate_state, + save_state=save_state, ) @@ -262,7 +261,7 @@ def _execute_experiment_job_locally( deterministic_trainer: bool, strategy: str, precision: str, - retain_intermediate_state: bool, + save_state: bool, ) -> None: """Runs an experiment, combining hyperparameter tuning and model for multiple updates. @@ -310,7 +309,7 @@ def _execute_experiment_job_locally( # TODO: evaluate's trainer has to use devices=1: # See https://github.com/Lightning-AI/lightning/issues/2537 - # The fix is to launch evaluation in a seperate process like training. + # The fix is to launch evaluation in a separate process like training. results: Dict[str, List[List[float]]] = {} evaluate_and_record_results( results, @@ -356,7 +355,7 @@ def _execute_experiment_job_locally( deterministic_trainer=deterministic_trainer, ) move_to_uri(output_state_url, input_state_url) - if retain_intermediate_state: + if save_state: copy_to_uri(input_state_url, update_url) model = get_model( config_module, @@ -384,13 +383,11 @@ def _execute_experiment_job_locally( cumulative_metrics = create_cumulative_metrics() df = cumulative_metrics_summary(results, cumulative_metrics, num_updates - 1) save_pandas_df_to_csv(df, defaults.metric_summary_file(logs_url)) - if not retain_intermediate_state: - move_to_uri( - defaults.hpo_file(input_state_url), defaults.logs_folder(experiment_outputs_url) - ) logger.info("### Cumulative results: ###") logger.info(df) + if not save_state: + move_to_uri(defaults.hpo_file(input_state_url), str(experiment_outputs_url)) move_to_uri(logs_url, defaults.logs_folder(experiment_outputs_url)) shutil.rmtree(working_directory) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index d8a16948..8e193270 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -54,7 +54,7 @@ JOB_NAME = "renate" SUPPORTED_TUNING_MODE = ["min", "max"] SUPPORTED_TUNING_MODE_TYPE = Literal["min", "max"] -RETAIN_INTERMEDIATE_STATE = True +SAVE_BENCHMARK_STATE = True SUPPORTED_BACKEND = ["local", "sagemaker"] SUPPORTED_BACKEND_TYPE = Literal["local", "sagemaker"] diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py index 9c19b43b..465c24ee 100644 --- a/src/renate/utils/file.py +++ b/src/renate/utils/file.py @@ -41,21 +41,30 @@ def _parse_s3_url(s3_url: str) -> Tuple[str, str]: def _move_locally( - source_dir: Union[str, Path], - destination_dir: Union[str, Path], + src: Union[str, Path], + dst: Union[str, Path], ignore_extensions: List[str] = [".sagemaker-uploading", ".sagemaker-uploaded"], copy: bool = False, ) -> None: - """Moves files to directory. If the files exist they are overwritten. + """Moves files in directory or file to directory. If the files exist they are overwritten. Args: - source_dir: Source directory. - destination_dir: Target directory. + src: Source directory or file. + dst: Target directory or file. ignore_extensions: List of extensions to ignore. copy: If `True`, copy instead of move. 
""" - for src_dir, _, files in os.walk(source_dir): - dst_dir = src_dir.replace(source_dir, destination_dir, 1) + if os.path.isfile(src): + os.makedirs(dst, exist_ok=True) + dst_file = os.path.join(dst, os.path.basename(src)) + if os.path.exists(dst_file): + os.remove(dst_file) + if copy: + shutil.copy(src, dst) + else: + shutil.move(src, dst) + for src_dir, _, files in os.walk(src): + dst_dir = src_dir.replace(src, dst, 1) if not os.path.exists(dst_dir): os.makedirs(dst_dir) for f in files: @@ -71,50 +80,89 @@ def _move_locally( shutil.move(src_f, dst_f) -def move_to_uri( - local_dir: Union[Path, str], - uri: str, +def _move_to_s3( + src: Union[str, Path], + dst: Union[str, Path], ignore_extensions: List[str] = [".sagemaker-uploading", ".sagemaker-uploaded"], + copy: bool = False, ) -> None: - """Moves files to directory or s3. If the files exist they are overwritten. - The files in the local directory are deleted. + """Moves files in directory or file to directory or s3. + + If the files exist they are overwritten. The files in the local directory are deleted. Args: - local_dir: Local directory to copy. - uri: Target directory or s3 uri. + src: Local file or directory to move. + dst: Target directory or s3 uri. ignore_extensions: List of extensions to ignore. + copy: If `True`, copy instead of move. """ - if is_s3_uri(uri): - upload_folder_to_s3(local_dir, uri, ignore_extensions=ignore_extensions) - for f in os.listdir(local_dir): - f = os.path.join(local_dir, f) - if os.path.isfile(f): - os.remove(f) - elif local_dir != uri: - _move_locally(local_dir, uri, ignore_extensions=ignore_extensions, copy=False) + if os.path.isfile(src): + dst_file = os.path.join(dst, os.path.basename(src)) + upload_file_to_s3(src, dst_file) + if not copy: + os.remove(src) else: - logging.warning(f"Source and destination are the same: {local_dir}") + upload_folder_to_s3(src, dst, ignore_extensions=ignore_extensions) + if not copy: + shutil.rmtree(src) -def copy_to_uri( - local_dir: Union[Path, str], - uri: str, +def _move_to_uri( + src: Union[Path, str], + dst: str, ignore_extensions: List[str] = [".sagemaker-uploading", ".sagemaker-uploaded"], + copy: bool = False, ) -> None: - """Copies files to directory or s3. If the files exist they are overwritten. - The files in the local directory are preserved. + """Moves files in directory or file to directory or s3. + + If the files exist they are overwritten. The files in the local directory are deleted. Args: - local_dir: Local directory to copy. - uri: Target directory or s3 uri. + src: Local file or directory to move. + dst: Target directory or s3 uri. ignore_extensions: List of extensions to ignore. + copy: If `True`, copy instead of move. """ - if is_s3_uri(uri): - upload_folder_to_s3(local_dir, uri, ignore_extensions=ignore_extensions) - elif local_dir != uri: - _move_locally(local_dir, uri, ignore_extensions=ignore_extensions, copy=True) + if is_s3_uri(dst): + _move_to_s3(src, dst, ignore_extensions=ignore_extensions, copy=copy) + elif src != dst: + _move_locally(src, dst, ignore_extensions=ignore_extensions, copy=copy) else: - logging.warning(f"Source and destination are the same: {local_dir}") + logging.warning(f"Source and destination are the same: {src}") + + +def move_to_uri( + src: Union[Path, str], + dst: str, + ignore_extensions: List[str] = [".sagemaker-uploading", ".sagemaker-uploaded"], +) -> None: + """Moves files in directory or file to directory or s3. + + If the files exist they are overwritten. The files in the local directory are deleted. 
+ + Args: + src: Local file or directory to move. + dst: Target directory or s3 uri. + ignore_extensions: List of extensions to ignore. + """ + _move_to_uri(src=src, dst=dst, ignore_extensions=ignore_extensions, copy=False) + + +def copy_to_uri( + src: Union[Path, str], + dst: str, + ignore_extensions: List[str] = [".sagemaker-uploading", ".sagemaker-uploaded"], +) -> None: + """Copies files in directory or file to directory or s3. + + If the files exist they are overwritten. The files in the local directory are preserved. + + Args: + src: Local directory to copy. + dst: Target directory or s3 uri. + ignore_extensions: List of extensions to ignore. + """ + _move_to_uri(src=src, dst=dst, ignore_extensions=ignore_extensions, copy=True) def maybe_download_from_s3(url: str, local_dir: Union[Path, str]) -> str: @@ -157,7 +205,7 @@ def upload_folder_to_s3( Args: local_dir: Folder containing files to be uploaded. - s3_url: + s3_url: Full path to s3 location. dst_bucket: s3 bucket. prefix: Prefix for all s3 object names. ignore_extensions: List of extensions to ignore. @@ -174,7 +222,7 @@ def upload_folder_to_s3( continue file_path = os.path.join(current_folder, file_name) object_name = os.path.join(prefix, current_folder[len(local_dir) + 1 :], file_name) - upload_file_to_s3(file_path, dst_bucket, object_name) + upload_file_to_s3(file_path, dst_bucket=dst_bucket, dst_object_name=object_name) def download_file_from_s3( @@ -197,18 +245,27 @@ def download_file_from_s3( def upload_file_to_s3( - src: Union[Path, str], dst_bucket: str, dst_object_name: Union[Path, str] + src: Union[Path, str], + s3_url: Optional[Union[Path, str]] = None, + dst_bucket: Optional[str] = None, + dst_object_name: Optional[Union[Path, str]] = None, ) -> bool: """Upload a file to an S3 bucket Args: src: File to upload. + s3_url: Full path to s3 location. dst_bucket: Destination S3 bucket dst_object_name: Destination S3 object Return: True if file was uploaded, else False """ + assert ( + s3_url is not None or dst_bucket is not None and dst_object_name is not None + ), "Either pass s3_url or both dst_bucket and dst_object_name." + if s3_url is not None: + dst_bucket, dst_object_name = _parse_s3_url(s3_url) s3_client = boto3.client("s3") logger.info(f"Upload file from {src} to s3://{dst_bucket}/{dst_object_name}") try: @@ -282,7 +339,7 @@ def save_pandas_df_to_csv(df: pd.DataFrame, file_path: Union[str, Path]) -> pd.D @rank_zero_only def unlink_file_or_folder(path: Path) -> None: - """Funtion to remove files and folders. + """Function to remove files and folders. Unlink works for files, rmdir for empty folders, but not for non-empty ones. Hence a recursive solution. 
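For readers following the file-utility rework in #305 above, a minimal usage sketch of the reworked helpers may be useful. It assumes that src/renate/utils/file.py is importable as renate.utils.file (as the diff suggests); every path and the bucket name below are placeholders for illustration, not values used anywhere by Renate.

    # Usage sketch (not part of the patch): exercises the reworked helpers from
    # src/renate/utils/file.py. All paths below are temporary placeholders.
    import os
    import tempfile

    from renate.utils.file import copy_to_uri, move_to_uri, upload_file_to_s3

    work_dir = tempfile.mkdtemp()
    checkpoint = os.path.join(work_dir, "model.ckpt")
    with open(checkpoint, "w") as f:
        f.write("dummy checkpoint")

    # Both helpers now accept a single file as `src`, not only a directory.
    copy_to_uri(checkpoint, os.path.join(work_dir, "backup"))   # local file is preserved
    move_to_uri(checkpoint, os.path.join(work_dir, "archive"))  # local file is removed

    # upload_file_to_s3 can now take a full S3 URL instead of bucket + object name.
    # Commented out because it requires AWS credentials; the bucket name is a placeholder.
    # upload_file_to_s3(
    #     os.path.join(work_dir, "archive", "model.ckpt"),
    #     s3_url="s3://my-bucket/renate-state/model.ckpt",
    # )

Routing both entry points through the shared _move_to_uri helper, as the patch does, keeps the S3-versus-local dispatch in one place; the copy flag is the only behavioral difference between move_to_uri and copy_to_uri.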
diff --git a/test/renate/benchmark/test_experimentation.py b/test/renate/benchmark/test_experimentation.py index 9ce75250..821b6c46 100644 --- a/test/renate/benchmark/test_experimentation.py +++ b/test/renate/benchmark/test_experimentation.py @@ -26,7 +26,8 @@ def experiment_job_kwargs(): } -def test_execute_experiment_job(tmpdir, experiment_job_kwargs): +@pytest.mark.parametrize("save_state", (True, False)) +def test_execute_experiment_job(tmpdir, experiment_job_kwargs, save_state): """Only checking if things run, not testing anything besides that.""" expected_columns = [ "Task ID", @@ -36,20 +37,18 @@ def test_execute_experiment_job(tmpdir, experiment_job_kwargs): "Backward Transfer", ] expected_num_updates = experiment_job_kwargs["num_updates"] + experiment_job_kwargs["save_state"] = save_state execute_experiment_job(experiment_outputs_url=tmpdir, **experiment_job_kwargs) results_df = pd.read_csv(str(Path(tmpdir) / "logs" / "metrics_summary.csv")) assert all(results_df.columns == expected_columns) - for update_id in range(expected_num_updates): - assert (Path(tmpdir) / f"update_{update_id}" / "learner.ckpt").is_file() - assert (Path(tmpdir) / f"update_{update_id}" / "model.ckpt").is_file() - assert ( - len( - pd.read_csv(str(Path(tmpdir) / f"update_{expected_num_updates - 1}" / "hpo.csv"))[ - "update_id" - ].unique() - ) - == expected_num_updates - ) + if save_state: + hpo_file = Path(tmpdir) / f"update_{expected_num_updates - 1}" / "hpo.csv" + for update_id in range(expected_num_updates): + assert (Path(tmpdir) / f"update_{update_id}" / "learner.ckpt").is_file() + assert (Path(tmpdir) / f"update_{update_id}" / "model.ckpt").is_file() + else: + hpo_file = Path(tmpdir) / "hpo.csv" + assert len(pd.read_csv(str(hpo_file))["update_id"].unique()) == expected_num_updates @pytest.mark.parametrize( diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index c183b9eb..1c1ebc87 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -1,6 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 import pytest +from torch.nn import Linear +from torch.optim import SGD +from torch.optim.lr_scheduler import StepLR from torchvision.transforms import Compose, Normalize from renate.benchmark import experiment_config @@ -11,6 +14,7 @@ get_data_module, get_scenario, loss_fn, + lr_scheduler_fn, metrics_fn, model_fn, models, @@ -286,6 +290,26 @@ def test_transforms_fails_for_unknown_dataset(): transform_function(unknown_dataset_set) +@pytest.mark.parametrize( + "learning_rate_scheduler,expected_lr_class,expected_interval", + (("StepLR", StepLR, "epoch"), (None, None, "epoch")), +) +def test_lr_scheduler_fn(learning_rate_scheduler, expected_lr_class, expected_interval): + scheduler, interval = lr_scheduler_fn(learning_rate_scheduler) + assert interval == expected_interval + if learning_rate_scheduler is None: + assert scheduler is None + else: + scheduler = scheduler(SGD(Linear(1, 1).parameters(), lr=0.1)) + assert isinstance(scheduler, expected_lr_class) + + +def test_lr_scheduler_fn_fails_for_unknown_scheduler(): + unknown_lr_scheduler = "UNKNOWN_SCHEDULER_NAME" + with pytest.raises(ValueError, match=f"Unknown scheduler `{unknown_lr_scheduler}`."): + lr_scheduler_fn(unknown_lr_scheduler) + + @pytest.mark.parametrize("model_name", [model_name for model_name in models]) @pytest.mark.parametrize("updater", ("ER", "Avalanche-iCaRL")) def test_prediction_strategy_is_correctly_set(model_name, updater): diff --git a/test/renate/utils/test_file.py b/test/renate/utils/test_file.py index b94e2616..e9a7d8a6 100644 --- a/test/renate/utils/test_file.py +++ b/test/renate/utils/test_file.py @@ -18,7 +18,7 @@ def create_file(path, content): "file_content2_destination_dir", [["test1_1", "test1_1", "test1_2", "test2_2"]], ) -def test_move_to_uri_locally( +def test_move_to_uri_locally_directory( tmpdir, file_content1_source_dir, file_content2_source_dir, @@ -53,12 +53,42 @@ def test_move_to_uri_locally( assert not os.path.exists(os.path.join(tmpdir, "source_dir", "file2.txt")) +def test_move_to_uri_locally_file(tmpdir): + """Test for moving a file to another local directory. + + The file should be moved from the source directory and the destination directory should be + created if it does not exist. + + If there are files with the same name in the destination directory as in the source directory + they should be overwritten. 
+ """ + + file_content = "content" + create_file(os.path.join(tmpdir, "source_dir", "file.txt"), file_content) + + move_to_uri( + os.path.join(tmpdir, "source_dir", "file.txt"), os.path.join(tmpdir, "destination_dir") + ) + with open(os.path.join(tmpdir, "destination_dir", "file.txt"), "r") as f: + assert f.read() == file_content + assert not os.path.exists(os.path.join(tmpdir, "source_dir", "file.txt")) + + file_content = "content2" + create_file(os.path.join(tmpdir, "source_dir", "file.txt"), file_content) + move_to_uri( + os.path.join(tmpdir, "source_dir", "file.txt"), os.path.join(tmpdir, "destination_dir") + ) + with open(os.path.join(tmpdir, "destination_dir", "file.txt"), "r") as f: + assert f.read() == file_content + assert not os.path.exists(os.path.join(tmpdir, "source_dir", "file.txt")) + + @pytest.mark.parametrize( "file_content1_source_dir, file_content2_source_dir, file_content1_destination_dir, " "file_content2_destination_dir", [["test1_1", "test1_1", "test1_2", "test2_2"]], ) -def test_copy_to_uri_locally( +def test_copy_to_uri_locally_directory( tmpdir, file_content1_source_dir, file_content2_source_dir, @@ -95,3 +125,26 @@ def test_copy_to_uri_locally( assert f.read() == file_content1_source_dir with open(os.path.join(tmpdir, "source_dir", "file2.txt"), "r") as f: assert f.read() == file_content2_source_dir + + +def test_copy_to_uri_locally_file(tmpdir): + """Test for copying a file from a local directory to another local directory. + + The file should be copied from the source directory and the destination directory should be + created if it does not exist. + + If there are files with the same name in the destination directory as in the source directory + they should be overwritten. + + The source directory should not be changed. + """ + file_content = "content" + create_file(os.path.join(tmpdir, "source_dir", "file.txt"), file_content) + + copy_to_uri( + os.path.join(tmpdir, "source_dir", "file.txt"), os.path.join(tmpdir, "destination_dir") + ) + with open(os.path.join(tmpdir, "destination_dir", "file.txt"), "r") as f: + assert f.read() == file_content + with open(os.path.join(tmpdir, "source_dir", "file.txt"), "r") as f: + assert f.read() == file_content From 59186366afe0bccdc7537dd554ddc6d81eb2fa9a Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 30 Jun 2023 13:23:46 +0200 Subject: [PATCH 43/89] Describe Installation of Dependencies for Benchmarking (#313) --- doc/benchmarking/index.rst | 7 +++++++ doc/getting_started/install.rst | 13 ++++++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/doc/benchmarking/index.rst b/doc/benchmarking/index.rst index 2debb4f8..afa750b0 100644 --- a/doc/benchmarking/index.rst +++ b/doc/benchmarking/index.rst @@ -14,6 +14,13 @@ At the core of this feature is the function :py:func:`~renate.benchmark.experime For the reader familiar with the function :py:func:`~renate.training.training.run_training_job`, the use will be very intuitive. +Renate's benchmarking functionality may require additional dependencies. +Please install them via + +.. code-block:: bash + + pip install Renate[benchmark] + In the following chapters, we will discuss how this interface can be used to experiment on :doc:`Renate benchmarks ` as well as :doc:`custom benchmarks `. 
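Since the benchmarking workflow documented above is driven by the experiment configuration module, a brief sketch of how the lr_scheduler_fn hook added to experiment_config.py in #305 is consumed may help. The snippet mirrors the unit test added in the same patch; the import path follows the diff's module location, and the toy model, optimizer, and learning rate are illustrative only.

    # Sketch of consuming the lr_scheduler_fn hook from #305; mirrors the accompanying unit test.
    from torch.nn import Linear
    from torch.optim import SGD

    from renate.benchmark.experiment_config import lr_scheduler_fn

    scheduler_factory, interval = lr_scheduler_fn(
        learning_rate_scheduler="StepLR",
        learning_rate_scheduler_step_size=10,
        learning_rate_scheduler_gamma=0.1,
    )
    optimizer = SGD(Linear(1, 1).parameters(), lr=0.1)
    scheduler = scheduler_factory(optimizer)  # instantiates torch.optim.lr_scheduler.StepLR
    assert interval == "epoch"

Returning a partial object rather than an instantiated scheduler lets the optimizer be bound later, once the model parameters exist; passing learning_rate_scheduler=None simply yields (None, "epoch").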
diff --git a/doc/getting_started/install.rst b/doc/getting_started/install.rst index 32e264d7..13ad2f55 100644 --- a/doc/getting_started/install.rst +++ b/doc/getting_started/install.rst @@ -13,7 +13,18 @@ If you want to use additional methods that require the Avalanche library, please pip install Renate[avalanche] -For Renate contributors we recommend using ``dev``. +If you want to use Renate for :doc:`benchmarking <../benchmarking/index>`, please use + +.. code-block:: bash + + pip install Renate[benchmark] + +Renate contributors should use + +.. code-block:: bash + + pip install Renate[dev] + This will install further dependencies which are required for code formatting and unit testing. Alternatively, if you want to access the code directly (e.g., for developing and running new methods) From 91f26670f65d91c82f5676ae59459af3745948b7 Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 30 Jun 2023 18:20:57 +0200 Subject: [PATCH 44/89] Upgrade to torchmetrics 0.11 (#314) --- examples/getting_started/renate_config.py | 2 +- .../simple_classifier_cifar10/renate_config.py | 2 +- examples/train_mlp_locally/renate_config.py | 2 +- requirements.txt | 2 +- src/renate/benchmark/experiment_config.py | 12 +++++------- src/renate/benchmark/experimentation.py | 4 +++- src/renate/cli/parsing_functions.py | 9 +++++++++ src/renate/cli/run_training.py | 5 ++++- src/renate/evaluation/metrics/classification.py | 17 ++++++++++------- src/renate/utils/module.py | 6 ++++-- .../benchmark/test_experimentation_config.py | 9 ++++++--- test/renate/cli/test_parsing_functions.py | 6 ++++++ test/renate/renate_config_files/config.py | 4 ++-- .../config_custom_optimizer.py | 2 +- .../renate_config_files/config_scenario.py | 2 +- test/renate/updaters/test_model_updater.py | 2 ++ 16 files changed, 57 insertions(+), 29 deletions(-) diff --git a/examples/getting_started/renate_config.py b/examples/getting_started/renate_config.py index 9a1e883b..1b5e8456 100644 --- a/examples/getting_started/renate_config.py +++ b/examples/getting_started/renate_config.py @@ -93,7 +93,7 @@ def buffer_transform() -> Callable: def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} + return {"accuracy": Accuracy(task="multiclass", num_classes=10)} def loss_fn() -> torch.nn.Module: diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 4cf5883a..1594dcac 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -69,4 +69,4 @@ def loss_fn() -> torch.nn.Module: def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} + return {"accuracy": Accuracy(task="multiclass", num_classes=10)} diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index 58b5f42a..decf622d 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -56,4 +56,4 @@ def loss_fn() -> torch.nn.Module: def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} + return {"accuracy": Accuracy(task="multiclass", num_classes=10)} diff --git a/requirements.txt b/requirements.txt index c35059b5..4c011b0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,7 +9,7 @@ pytorch-lightning>=1.8.0, <1.9.5 Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 tensorboardX>=2.5.0, <2.5.2 -torchmetrics~=0.10.3 +torchmetrics>=0.11.0, <0.11.5 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 datasets>=2.9.0, <2.12.1 diff --git 
a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 8854c39e..7164c556 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -7,7 +7,7 @@ import wild_time_data from torch.optim import Optimizer from torch.optim.lr_scheduler import StepLR, _LRScheduler -from torchmetrics import Accuracy +from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import transforms from transformers import AutoTokenizer @@ -64,11 +64,11 @@ def model_fn( + num_outputs: int, model_state_url: Optional[str] = None, updater: Optional[str] = None, model_name: Optional[str] = None, num_inputs: Optional[int] = None, - num_outputs: Optional[int] = None, num_hidden_layers: Optional[int] = None, hidden_size: Optional[Tuple[int]] = None, dataset_name: Optional[str] = None, @@ -78,7 +78,7 @@ def model_fn( if model_name not in models: raise ValueError(f"Unknown model `{model_name}`") model_class = models[model_name] - model_kwargs = {} + model_kwargs = {"num_outputs": num_outputs} if updater == "Avalanche-iCaRL": model_kwargs["prediction_strategy"] = ICaRLClassificationStrategy() if model_name == "MultiLayerPerceptron": @@ -95,8 +95,6 @@ def model_fn( if updater == "Avalanche-iCaRL": raise ValueError("Transformers do not support iCaRL.") model_kwargs["pretrained_model_name"] = pretrained_model_name - if num_outputs is not None: - model_kwargs["num_outputs"] = num_outputs if model_state_url is None: model = model_class(**model_kwargs) else: @@ -361,5 +359,5 @@ def lr_scheduler_fn( raise ValueError(f"Unknown scheduler `{learning_rate_scheduler}`.") -def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} +def metrics_fn(num_outputs: int) -> Dict: + return {"accuracy": MulticlassAccuracy(num_classes=num_outputs, average="micro")} diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 572dff31..c2e90fc9 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -14,6 +14,7 @@ import renate.defaults as defaults from renate.cli.parsing_functions import ( get_data_module_fn_kwargs, + get_metrics_fn_kwargs, get_model_fn_kwargs, get_scheduler_kwargs, get_transforms_kwargs, @@ -300,7 +301,8 @@ def _execute_experiment_job_locally( data_module.test_data() ), f"The dataset has {len(data_module.test_data())} chunks, expected {num_updates}." 
transforms = get_transforms_kwargs(config_module, config_space) - metrics = get_metrics(config_module) + metrics_fn_kwargs = get_metrics_fn_kwargs(config_module, config_space) + metrics = get_metrics(config_module, **metrics_fn_kwargs) torch.save( model.state_dict(), diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index a3b55fee..7b713c51 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -741,6 +741,15 @@ def get_transforms_kwargs( return transforms +def get_metrics_fn_kwargs( + config_module, config_space: Dict[str, Any], cast_arguments: Optional[bool] = False +) -> Dict[str, Any]: + """Returns the kwargs for a ``metrics_fn`` with defined arguments based on config_space.""" + return _get_function_kwargs_helper( + config_module, config_space, "metrics_fn", [], cast_arguments + ) + + def _get_function_kwargs_helper( config_module, config_space: Dict[str, Any], diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index bc5b9f99..bb08f050 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -149,7 +149,10 @@ def run(self): if lr_scheduler_config is not None: lr_scheduler_kwargs["learning_rate_scheduler"] = lr_scheduler_config[0] lr_scheduler_kwargs["learning_rate_scheduler_interval"] = lr_scheduler_config[1] - metrics = get_metrics(config_module) + metrics = get_metrics( + config_module, + **get_function_kwargs(args=args, function_args=function_args["metrics_fn"]), + ) model_updater_class, learner_kwargs = get_updater_and_learner_kwargs(args) diff --git a/src/renate/evaluation/metrics/classification.py b/src/renate/evaluation/metrics/classification.py index dee911dc..dde5fc5e 100644 --- a/src/renate/evaluation/metrics/classification.py +++ b/src/renate/evaluation/metrics/classification.py @@ -47,7 +47,7 @@ def forgetting(results: Dict[str, List[List[float]]], task_id: int) -> float: if task_id == 0: return 0.0 - def f(results: Dict[str, List[List[float]]], j: int, i: int) -> float: + def f(results: List[List[float]], j: int, i: int) -> float: """A Helper function to compute the: math:`f_{j,i}`.""" accuracy_ji = results[j][i] max_accuracy_ki = max([results[k][i] for k in range(j)]) @@ -102,9 +102,12 @@ def forward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> flo """ if task_id == 0: return 0.0 - return sum( - [ - results["accuracy"][i - 1][i] - results["accuracy_init"][0][i] - for i in range(1, task_id + 1) - ] - ) / (task_id) + return ( + sum( + [ + results["accuracy"][i - 1][i] - results["accuracy_init"][0][i] + for i in range(1, task_id + 1) + ] + ) + / task_id + ) diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 0b0c8cf3..133f6830 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -127,11 +127,13 @@ def get_learning_rate_scheduler( return getattr(config_module, lr_scheduler_fn_name)(**kwargs) -def get_metrics(config_module: ModuleType) -> Optional[Dict[str, torchmetrics.Metric]]: +def get_metrics( + config_module: ModuleType, **kwargs: Any +) -> Optional[Dict[str, torchmetrics.Metric]]: """Creates and returns a dictionary of metrics.""" metrics_fn_name = "metrics_fn" if metrics_fn_name in vars(config_module): - return getattr(config_module, metrics_fn_name)() + return getattr(config_module, metrics_fn_name)(**kwargs) def get_and_prepare_data_module(config_module: ModuleType, **kwargs: Any) -> RenateDataModule: diff --git a/test/renate/benchmark/test_experimentation_config.py 
b/test/renate/benchmark/test_experimentation_config.py index 1c1ebc87..2ad69e8a 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -4,6 +4,7 @@ from torch.nn import Linear from torch.optim import SGD from torch.optim.lr_scheduler import StepLR +from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import Compose, Normalize from renate.benchmark import experiment_config @@ -42,7 +43,7 @@ def test_model_fn(model_name, expected_model_class): model_state_url=None, model_name=model_name, num_inputs=1 if model_name == "MultiLayerPerceptron" else None, - num_outputs=1 if model_name in ["MultiLayerPerceptron", "HuggingFaceTransformer"] else None, + num_outputs=2, num_hidden_layers=1 if model_name == "MultiLayerPerceptron" else None, hidden_size=1 if model_name == "MultiLayerPerceptron" else None, pretrained_model_name="distilbert-base-uncased" @@ -66,6 +67,7 @@ def test_model_fn(model_name, expected_model_class): def test_model_fn_automatic_input_channel_detection_resnet(dataset_name, expected_in_channels): """Tests if ResNet architectures input channels are correctly adapted to the dataset.""" model = model_fn( + num_outputs=10, model_state_url=None, model_name="ResNet18", dataset_name=dataset_name, @@ -76,7 +78,7 @@ def test_model_fn_automatic_input_channel_detection_resnet(dataset_name, expecte def test_model_fn_fails_for_unknown_model(): unknown_model_name = "UNKNOWN_MODEL_NAME" with pytest.raises(ValueError, match=f"Unknown model `{unknown_model_name}`"): - model_fn(model_name=unknown_model_name) + model_fn(num_outputs=10, model_name=unknown_model_name) @pytest.mark.parametrize( @@ -340,4 +342,5 @@ def test_loss_fn_returns_correct_reduction_type(): def test_metrics_fn_contains_accuracy(): - assert "accuracy" in metrics_fn() + assert isinstance(metrics_fn(num_outputs=2)["accuracy"], MulticlassAccuracy) + assert isinstance(metrics_fn(num_outputs=10)["accuracy"], MulticlassAccuracy) diff --git a/test/renate/cli/test_parsing_functions.py b/test/renate/cli/test_parsing_functions.py index 8071f213..24db77ea 100644 --- a/test/renate/cli/test_parsing_functions.py +++ b/test/renate/cli/test_parsing_functions.py @@ -13,6 +13,7 @@ get_argument_type, get_data_module_fn_kwargs, get_function_args, + get_metrics_fn_kwargs, get_model_fn_kwargs, to_dense_str, ) @@ -217,6 +218,7 @@ def test_get_fn_kwargs_helper_functions(): "class_groupings": to_dense_str(expected_data_module_kwargs["class_groupings"]), "optional_float": to_dense_str(expected_data_module_kwargs["optional_float"]), "bool_param": to_dense_str(expected_data_module_kwargs["bool_param"]), + "num_outputs": 10, } data_module_kwargs = get_data_module_fn_kwargs( config_module=config_module, config_space=config_space, cast_arguments=True @@ -226,3 +228,7 @@ def test_get_fn_kwargs_helper_functions(): config_module=config_module, config_space=config_space, cast_arguments=True ) assert model_kwargs == {"model_state_url": config_space["model_state_url"]} + metrics_kwargs = get_metrics_fn_kwargs( + config_module=config_module, config_space=config_space, cast_arguments=True + ) + assert metrics_kwargs == {"num_outputs": config_space["num_outputs"]} diff --git a/test/renate/renate_config_files/config.py b/test/renate/renate_config_files/config.py index 0ea96599..5b7a6900 100644 --- a/test/renate/renate_config_files/config.py +++ b/test/renate/renate_config_files/config.py @@ -37,5 +37,5 @@ def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: 
return torch.nn.CrossEntropyLoss(reduction="none") -def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} +def metrics_fn(num_outputs: int = 10) -> Dict: + return {"accuracy": Accuracy(task="multiclass", num_classes=num_outputs)} diff --git a/test/renate/renate_config_files/config_custom_optimizer.py b/test/renate/renate_config_files/config_custom_optimizer.py index eac8a25f..78da9244 100644 --- a/test/renate/renate_config_files/config_custom_optimizer.py +++ b/test/renate/renate_config_files/config_custom_optimizer.py @@ -45,4 +45,4 @@ def lr_scheduler_fn() -> Tuple[Callable[[Optimizer], _LRScheduler], str]: def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} + return {"accuracy": Accuracy(task="multiclass", num_classes=10)} diff --git a/test/renate/renate_config_files/config_scenario.py b/test/renate/renate_config_files/config_scenario.py index 77ec44cc..dac20391 100644 --- a/test/renate/renate_config_files/config_scenario.py +++ b/test/renate/renate_config_files/config_scenario.py @@ -40,4 +40,4 @@ def loss_fn(updater: Optional[str] = None) -> torch.nn.Module: def metrics_fn() -> Dict: - return {"accuracy": Accuracy()} + return {"accuracy": Accuracy(task="multiclass", num_classes=10)} diff --git a/test/renate/updaters/test_model_updater.py b/test/renate/updaters/test_model_updater.py index 9f978b57..8d76428d 100644 --- a/test/renate/updaters/test_model_updater.py +++ b/test/renate/updaters/test_model_updater.py @@ -5,6 +5,7 @@ import pytest import torch +from pytorch_lightning.utilities.seed import seed_everything from torchvision.transforms import Lambda from conftest import LEARNERS_USING_SIMPLE_UPDATER, LEARNER_KWARGS, check_learner_transforms @@ -73,6 +74,7 @@ def test_deterministic_updater(): def test_model_updater_with_early_stopping( use_val, early_stopping_enabled, metric_monitored, updater_type ): + seed_everything(0) model, train_dataset, val_dataset = pytest.helpers.get_renate_module_mlp_and_data( num_inputs=10, num_outputs=10, From 044e14b10a9b7002db571e01b7feb3100e484505 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 12 Jul 2023 11:06:11 +0200 Subject: [PATCH 45/89] Enable Downloading Large Files (#337) --- src/renate/utils/file.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/renate/utils/file.py b/src/renate/utils/file.py index 465c24ee..f7db7a68 100644 --- a/src/renate/utils/file.py +++ b/src/renate/utils/file.py @@ -305,8 +305,11 @@ def download_file( if src_bucket is None: if not os.path.exists(os.path.join(data_path, dataset_name)): os.makedirs(os.path.join(data_path, dataset_name)) - response = requests.get(os.path.join(url, file_name), allow_redirects=True) - open(os.path.join(data_path, dataset_name, file_name), "wb").write(response.content) + with requests.get(os.path.join(url, file_name), allow_redirects=True, stream=True) as r: + r.raise_for_status() + with open(os.path.join(data_path, dataset_name, file_name), "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) else: download_file_from_s3( src_bucket, From 8da203520e5493298a1ecba3ce3a865b9c98fb7a Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Fri, 21 Jul 2023 16:14:16 +0200 Subject: [PATCH 46/89] Adding a Datacollator to handle the wild time text datasets (#338) Co-authored-by: Prabhu Teja --- .../benchmark/datasets/wild_time_data.py | 5 ++ src/renate/updaters/experimental/er.py | 2 + src/renate/updaters/experimental/gdumb.py | 1 + src/renate/updaters/experimental/joint.py | 1 + .../updaters/experimental/offline_er.py | 1 + 
.../updaters/experimental/repeated_distill.py | 1 + src/renate/utils/hf_utils.py | 60 +++++++++++++++++ src/renate/utils/module.py | 1 + src/renate/utils/pytorch.py | 3 +- test/renate/utils/test_hf_utils.py | 66 +++++++++++++++++++ 10 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 src/renate/utils/hf_utils.py create mode 100644 test/renate/utils/test_hf_utils.py diff --git a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py index 2655ae85..85f0857d 100644 --- a/src/renate/benchmark/datasets/wild_time_data.py +++ b/src/renate/benchmark/datasets/wild_time_data.py @@ -8,6 +8,7 @@ from renate import defaults from renate.data.data_module import RenateDataModule from renate.utils.file import download_folder_from_s3 +from renate.utils.hf_utils import DataCollatorWithPaddingForWildTime class WildTimeDataModule(RenateDataModule): @@ -95,3 +96,7 @@ def setup(self) -> None: train_data = load_dataset(split="train", **kwargs) self._train_data, self._val_data = self._split_train_val_data(train_data) self._test_data = load_dataset(split="test", **kwargs) + if self._dataset_name in ["huffpost", "arxiv"]: + self._train_collate_fn = DataCollatorWithPaddingForWildTime(tokenizer=self._tokenizer) + self._val_collate_fn = DataCollatorWithPaddingForWildTime(tokenizer=self._tokenizer) + self._test_collate_fn = DataCollatorWithPaddingForWildTime(tokenizer=self._tokenizer) diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index e32a06c1..c2427b83 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -110,6 +110,7 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) def on_train_start(self) -> None: @@ -204,6 +205,7 @@ def _set_memory_loader(self) -> None: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None: diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index 71379b3d..5c953706 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -88,6 +88,7 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) def training_step( diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index c21ca9fa..6ab5e52e 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -81,6 +81,7 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) def training_step( diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 49b03bda..13c893a1 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -87,6 +87,7 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) return CombinedLoader(loaders, mode="max_size_cycle") diff --git a/src/renate/updaters/experimental/repeated_distill.py b/src/renate/updaters/experimental/repeated_distill.py index 6a127f3c..dc3264f1 100644 --- a/src/renate/updaters/experimental/repeated_distill.py +++ 
b/src/renate/updaters/experimental/repeated_distill.py @@ -255,6 +255,7 @@ def train_dataloader(self) -> DataLoader: shuffle=True, generator=self._rng, pin_memory=True, + collate_fn=self._train_collate_fn, ) def on_model_update_end(self) -> None: diff --git a/src/renate/utils/hf_utils.py b/src/renate/utils/hf_utils.py new file mode 100644 index 00000000..09d66bb9 --- /dev/null +++ b/src/renate/utils/hf_utils.py @@ -0,0 +1,60 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +import torch +from torch.utils.data import default_collate +from transformers import BatchEncoding, DataCollatorWithPadding + + +@dataclass +class DataCollatorWithPaddingForWildTime(DataCollatorWithPadding): + """A data collator class that can handle wild time data (non-standard) batches. + + This adds to the `transformer` library's DataCollatorWithPadding. That data collator expects + data in a standard HF format. Wild time data format is slightly different: We get a + tuple of BatchEncoding (dict) and a class label. When being read from a buffer, an additional + metadata attribute is present. These cases are not handled by the orig data collator. + The code here only separates the input data into format original collator can handle and + undoes the data packing: see parts after super().__call__. + """ + + def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]: + # first_element type determines if data sources + # known types : {features}, class (from wildtimedataset) + # : ({features}, class), metadata (from a replay buffer : val data?) + # : index, ({features}, class) (by _EnumeratedDataset in BaseER traindata) + # Note that {features} can be a dict or BatchEncoding. + first_element = features[0][0] + assert len(features[0]) == 2, "Input to collator should contain only two elements" + if isinstance(first_element, (BatchEncoding, dict)): + # This is normal dataset and not a buffer with metadata, do the default things + return self._collate_hf_class_tuples(features) + elif isinstance(first_element, tuple): + # this has metadata possibly. + collated, labels = self._collate_hf_class_tuples([elem[0] for elem in features]) + metadata_to_collate = [elem[1] for elem in features] + collated_metadata = default_collate(metadata_to_collate) + return (collated, labels), collated_metadata + elif isinstance(first_element, (torch.Tensor, int)): + ## this is EnumeratedDataset. + collated_indices = default_collate([elem[0] for elem in features]) + collated, labels = self._collate_hf_class_tuples([elem[1] for elem in features]) + return collated_indices, (collated, labels) + + else: + raise ValueError( + f"Unknown structure to collate. 
Got {first_element} of {type(first_element)}" + ) + + def _collate_hf_class_tuples( + self, features: Tuple[Union[BatchEncoding, Dict[str, Any]], Union[torch.Tensor, int]] + ): + to_be_collated = [] + for elem in features: + elem[0]["labels"] = elem[1] + to_be_collated.append(elem[0]) + collated = super().__call__(to_be_collated) + labels = collated.pop("labels") + return collated, labels diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 133f6830..6ac35735 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -62,6 +62,7 @@ def evaluate_and_record_results( update_results = evaluate( model=model, test_dataset=data_module.test_data(), + test_collate_fn=data_module.test_collate_fn(), task_id=task_id, batch_size=batch_size, transform=transform, diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 6f130236..970a08fe 100644 --- a/src/renate/utils/pytorch.py +++ b/src/renate/utils/pytorch.py @@ -6,6 +6,7 @@ import torch from torch.utils.data import Dataset, random_split +from transformers import BatchEncoding from renate import defaults from renate.types import NestedTensors @@ -72,7 +73,7 @@ def move_tensors_to_device(tensors: NestedTensors, device: torch.device) -> Nest The collection `tensors` can be a nested structure of tensors, tuples, lists, and dicts. """ - if isinstance(tensors, torch.Tensor): + if isinstance(tensors, (BatchEncoding, torch.Tensor)): return tensors.to(device) elif isinstance(tensors, tuple): return tuple(move_tensors_to_device(t, device) for t in tensors) diff --git a/test/renate/utils/test_hf_utils.py b/test/renate/utils/test_hf_utils.py new file mode 100644 index 00000000..7debdba6 --- /dev/null +++ b/test/renate/utils/test_hf_utils.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +from transformers import AutoTokenizer + +from renate.utils.hf_utils import BatchEncoding, DataCollatorWithPaddingForWildTime + + +@pytest.mark.parametrize( + "features, shape, error", + [ + ( + [ + ({"input_ids": [0, 1, 2]}, 0), + ({"input_ids": [0, 1, 2, 3, 4, 5]}, 1), + ], + torch.Size([2, 6]), + False, + ), + ( + [ + (({"input_ids": [1, 2, 4]}, 0), {}), + (({"input_ids": [4, 5, 6, 7, 8, 9]}, 1), {}), + ], + torch.Size([2, 6]), + False, + ), + ( + [ + (({"input_ids": [1, 2, 4]}, 0), {"logits": [0.1, 0.2, 0.3]}), + (({"input_ids": [4, 5, 6, 7, 8, 9]}, 1), {"logits": [0.3, 0.4, 0.5]}), + ], + torch.Size([2, 6]), + False, + ), + ( + [[1, ({"input_ids": [0, 1, 2]}, 0)], [2, ({"input_ids": [0, 1, 2, 3, 4, 5]}, 1)]], + torch.Size([2]), + False, + ), + ( + [ + ({"input_ids": [1, 2, 4]}, 0, {}), + ], + torch.Size([1, 3]), + True, + ), + ], +) +@pytest.mark.parametrize("tokenizer", [AutoTokenizer.from_pretrained("bert-base-cased")]) +def test_serialize_wildtime_collator(features, shape, error, tokenizer): + if not error: + collator = DataCollatorWithPaddingForWildTime(tokenizer) + out = collator(features)[0] + print(collator(features)) + if isinstance(out, (dict, BatchEncoding)): + assert out["input_ids"].shape == shape + elif isinstance(out, torch.Tensor): + assert out.shape == shape + else: + assert out[0]["input_ids"].shape == shape + else: + collator = DataCollatorWithPaddingForWildTime(tokenizer) + with pytest.raises(Exception): + out = collator(features)[0] From 7ee9903e052a94c62b2d07e8ddf1c2e969dc9548 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Jul 2023 12:23:12 +0200 Subject: [PATCH 47/89] Upgrade wild-time-data package (#346) --- doc/requirements.txt | 2 +- pyproject.toml | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index e16ca83b..18dc8184 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -9,4 +9,4 @@ sphinx-paramlinks==0.5.4 # Temporarily added avalanche_lib==0.3.1 -wild-time-data==0.1.0 +wild-time-data==0.1.1 diff --git a/pyproject.toml b/pyproject.toml index 94835147..fd7864ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,12 +22,12 @@ avalanche = [ "torch>=1.10.0, <1.12.2", ] benchmark = [ - "wild-time-data==0.1.0", + "wild-time-data==0.1.1", ] dev = [ "black==23.3.0", "avalanche_lib==0.3.1", - "wild-time-data==0.1.0", + "wild-time-data==0.1.1", "torch>=1.10.0, <1.12.2", # PyTest Dependencies "pytest==7.3.1", From ada173bfd84f7b064bda199257d84079432ae26e Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Jul 2023 16:07:38 +0200 Subject: [PATCH 48/89] Fix Scenario for CLEAR (#339) --- doc/benchmarking/renate_benchmarks.rst | 9 ++- .../benchmark/datasets/vision_datasets.py | 72 ++++++++----------- src/renate/benchmark/experiment_config.py | 24 ++++--- src/renate/benchmark/scenarios.py | 31 ++++---- test/dummy_datasets.py | 39 ---------- .../benchmark/test_experimentation_config.py | 39 +++++----- test/renate/benchmark/test_scenarios.py | 27 ++++--- 7 files changed, 95 insertions(+), 146 deletions(-) diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index 402ad19a..fcb8c82a 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -184,13 +184,12 @@ The first part contains all instances with classes 1 and 2, the second with clas * - Scenario Name - Description - Settings - * - 
:py:class:`~renate.benchmark.scenarios.BenchmarkScenario` - - Used in combination only with CLEAR-10 or CLEAR-100. - - * :code:`num_tasks`: Number of data partitions. - * - :py:class:`~renate.benchmark.scenarios.WildTimeScenario` - - Used in combination only with Wild-Time datasets. This is not the scenario used in the paper. + * - :py:class:`~renate.benchmark.scenarios.TimeIncrementalScenario` + - Used in combination only with Wild-Time datasets or CLEAR. Data is presented time step by time step and the model is evaluated on test data up to the current time step. + This means that for the Wild-Time datasets, is a different scenario than in the original + Wild-Time data paper. - * :code:`num_tasks`: Number of data partitions. * - :py:class:`~renate.benchmark.scenarios.ClassIncrementalScenario` - Creates data partitions by splitting the data according to class labels. diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index c2b3ed46..215e978d 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -1,12 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import json import os from pathlib import Path from typing import List, Optional, Tuple, Union import torch import torchvision -from torch.utils.data import Dataset from torchvision import transforms from renate import defaults @@ -127,8 +127,6 @@ class TorchVisionDataModule(RenateDataModule): }, "FashionMNIST": {"mean": 0.2860405969887955, "std": 0.3530242445149223}, "MNIST": {"mean": 0.1306604762738429, "std": 0.30810780385646264}, - "CLEAR10": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, - "CLEAR100": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, } def __init__( @@ -196,16 +194,20 @@ class CLEARDataModule(RenateDataModule): Args: data_path: the path to the folder containing the dataset files. + time_step: Loads CLEAR dataset for this time step. Options: CLEAR10: [0,9], CLEAR100: [0,10] src_bucket: the name of the s3 bucket. If not provided, downloads the data from original source. src_object_name: the folder path in the s3 bucket. dataset_name: CLEAR dataset name, options are clear10 and clear100. - chunk_id: Used to define the CLEAR dataset splits. There are 10 splits in total with ids - from 0 to 9. val_size: Fraction of the training data to be used for validation. seed: Seed used to fix random number generation. 
""" + dataset_stats = { + "CLEAR10": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, + "CLEAR100": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]}, + } + md5s = { "clear10-train-image-only.zip": "5171f720810d60b471c308dee595d430", "clear100-train-image-only.zip": "ea85cdba9efcb3abf77eaab5554052c8", @@ -216,10 +218,10 @@ class CLEARDataModule(RenateDataModule): def __init__( self, data_path: Union[Path, str], + time_step: int = 0, src_bucket: Optional[str] = None, src_object_name: Optional[str] = None, dataset_name: str = "CLEAR10", - chunk_id: int = 0, val_size: float = defaults.VALIDATION_SIZE, seed: int = defaults.SEED, ): @@ -232,12 +234,8 @@ def __init__( ) self._dataset_name = dataset_name.lower() assert self._dataset_name in ["clear10", "clear100"] - self._verify_chunk_id(chunk_id) - self._chunk_id = chunk_id - - def _verify_chunk_id(self, chunk_id: int) -> None: - """Verify that the chunk_id is valid.""" - assert 0 <= chunk_id <= 9 + assert 0 <= time_step <= (9 if self._dataset_name == "clear10" else 10) + self.time_step = time_step def prepare_data(self) -> None: """Download CLEAR dataset with given dataset_name (clear10/clear100).""" @@ -257,42 +255,34 @@ def prepare_data(self) -> None: def setup(self) -> None: """Set up train, test and val datasets.""" - X, y = self._get_filepaths_and_labels(train=True, chunk_id=self._chunk_id) + time_step = self.time_step + 1 if self._dataset_name == "clear10" else self.time_step + X, y = self._get_filepaths_and_labels(train=True, time_step=time_step) train_data = ImageDataset(X, y, transform=transforms.ToTensor()) self._train_data, self._val_data = self._split_train_val_data(train_data) - self._test_data = [] - for i in range(10): - X, y = self._get_filepaths_and_labels(train=False, chunk_id=i) - self._test_data.append(ImageDataset(X, y, transform=transforms.ToTensor())) - - # TODO: This is a work-around to make CLEAR available as a data module (single test dataset) - # as well as a scenario (all test datasets) via BenchmarkScenario. Clean up! - def test_data(self) -> Dataset: - return self._test_data[self._chunk_id] + X, y = self._get_filepaths_and_labels(train=False, time_step=time_step) + self._test_data = ImageDataset(X, y, transform=transforms.ToTensor()) - def _get_filepaths_and_labels(self, train: bool, chunk_id: int) -> Tuple[List[str], List[int]]: + def _get_filepaths_and_labels(self, train: bool, time_step: int) -> Tuple[List[str], List[int]]: """Extracts all the filepaths and labels for a given chunk id and split.""" - data = [] - labels = [] path = os.path.join(self._data_path, self._dataset_name) - path = os.path.join(path, "train_image_only" if train else "test") # Load the class names and create a class mapping. 
The class names are in `class_names.txt` - label_encoding = {} - with open(os.path.join(path, "class_names.txt"), "r") as f: - class_names = [line.strip() for line in f.readlines() if line.strip() != "BACKGROUND"] + with open(os.path.join(path, "train_image_only", "class_names.txt"), "r") as f: + class_names = [line.strip() for line in f.readlines()] label_encoding = {name: cnt for cnt, name in enumerate(class_names)} - path = os.path.join(path, "labeled_images", str(chunk_id + 1)) - - # Go through all the subfolders in the path folder and search for all .jpg images - for root, _, files in os.walk(path): - for file in files: - if file.endswith(".jpg"): - folder = root.split("/")[-1] - if folder == "BACKGROUND": - continue - data.append(os.path.join(root, file)) - labels.append(label_encoding[folder]) + path = os.path.join(path, "train_image_only" if train else "test") + with open(os.path.join(path, "labeled_metadata.json"), "r") as f: + metadata = json.load(f) - return data, labels + image_paths = [] + labels = [] + for class_name, class_metadata_file in metadata[str(time_step)].items(): + label = label_encoding[class_name] + with open(os.path.join(path, class_metadata_file), "r") as f: + class_metadata = json.load(f) + for image_metadata in class_metadata.values(): + image_paths.append(os.path.join(path, image_metadata["IMG_PATH"])) + labels.append(label) + + return image_paths, labels diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 7164c556..a09d0b7f 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -31,7 +31,6 @@ ) from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer from renate.benchmark.scenarios import ( - BenchmarkScenario, ClassIncrementalScenario, FeatureSortingScenario, HueShiftScenario, @@ -39,7 +38,7 @@ ImageRotationScenario, PermutationScenario, Scenario, - WildTimeScenario, + TimeIncrementalScenario, ) from renate.data.data_module import RenateDataModule from renate.models import RenateModule @@ -189,10 +188,6 @@ def get_scenario( class_groupings=class_groupings, chunk_id=chunk_id, ) - if scenario_name == "BenchmarkScenario": - return BenchmarkScenario( - data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed - ) if scenario_name == "IIDScenario": return IIDScenario( data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed @@ -226,8 +221,8 @@ def get_scenario( chunk_id=chunk_id, seed=seed, ) - if scenario_name == "WildTimeScenario": - return WildTimeScenario( + if scenario_name == "TimeIncrementalScenario": + return TimeIncrementalScenario( data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed ) raise ValueError(f"Unknown scenario `{scenario_name}`.") @@ -291,6 +286,11 @@ def _get_normalize_transform(dataset_name): TorchVisionDataModule.dataset_stats[dataset_name]["mean"], TorchVisionDataModule.dataset_stats[dataset_name]["std"], ) + if dataset_name in ["CLEAR10", "CLEAR100"]: + return transforms.Normalize( + CLEARDataModule.dataset_stats[dataset_name]["mean"], + CLEARDataModule.dataset_stats[dataset_name]["std"], + ) def train_transform(dataset_name: str) -> Optional[Callable]: @@ -311,7 +311,9 @@ def train_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100"]: return transforms.Compose( [ - transforms.Resize(224), + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True + ), 
transforms.RandomCrop(224), _get_normalize_transform(dataset_name), ] @@ -331,7 +333,9 @@ def test_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100"]: return transforms.Compose( [ - transforms.Resize(224), + transforms.Resize( + 224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True + ), transforms.CenterCrop(224), _get_normalize_transform(dataset_name), ] diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index 7a57b447..d0384e53 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -9,6 +9,7 @@ from torchvision.transforms import Lambda, RandomRotation, ToPILImage from renate import defaults +from renate.benchmark.datasets.vision_datasets import CLEARDataModule from renate.benchmark.datasets.wild_time_data import WildTimeDataModule from renate.data.data_module import RenateDataModule from renate.data.datasets import _TransformedDataset @@ -102,18 +103,6 @@ def _split_and_assign_train_and_val_data(self) -> None: self._val_data = randomly_split_data(val_data, proportions, self._seed)[self._chunk_id] -class BenchmarkScenario(Scenario): - """This is a scenario to concatenate test data of a data module, which by definition has - different chunks. - """ - - def setup(self) -> None: - super().setup() - self._train_data = self._data_module.train_data() - self._val_data = self._data_module.val_data() - self._test_data = self._data_module._test_data - - class ClassIncrementalScenario(Scenario): """A scenario that creates data chunks from data samples with specific classes from a data module. @@ -284,7 +273,7 @@ class _SortingScenario(Scenario): Randomness in the sorted order is induced by swapping the position of random pairs. Args: - data_module: The source RenateDataModule for the the user data. + data_module: The source RenateDataModule for the user data. num_tasks: The total number of expected tasks for experimentation. randomness: A value between 0 and 1. For a dataset with ``N`` data points, ``0.5 * N * randomness`` random pairs are swapped. @@ -399,11 +388,13 @@ def _get_scores(self, dataset: Dataset) -> List[float]: return scores -class WildTimeScenario(Scenario): - """Creating a time-incremental scenario for the Wild-Time datasets. +class TimeIncrementalScenario(Scenario): + """Creating a time-incremental scenario for specific datasets. - In contrast to the original work, data is presented time step by time step (no grouping) and - the test set is all data up to the current time step. + Supports the Wild Time datasets and CLEAR. + DataModules that want to use the TimeIncrementalScenario, need to have an attribute + ``time_step``. Setting this variable and then calling ``setup()`` should load the time-specific + datasets. Args: data_module: The source RenateDataModule for the user data. @@ -420,8 +411,10 @@ def __init__( seed: int = defaults.SEED, ) -> None: super().__init__(data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed) - if not isinstance(data_module, WildTimeDataModule): - raise ValueError("This scenario is only compatible with `WildTimeDataModule`.") + if not isinstance(data_module, (CLEARDataModule, WildTimeDataModule)): + raise ValueError( + "This scenario is only compatible with `CLEARDataModule` and `WildTimeDataModule`." 
+ ) def setup(self) -> None: """Sets up the scenario.""" diff --git a/test/dummy_datasets.py b/test/dummy_datasets.py index 01bce786..091d57d2 100644 --- a/test/dummy_datasets.py +++ b/test/dummy_datasets.py @@ -69,42 +69,3 @@ def setup(self): self._train_data, self._val_data = self._split_train_val_data(train_data) self.X_test, self.y_test = self._get_random_data() self._test_data = DummyDataset(self.X_test, self.y_test, self._transform) - - -class DummyTorchVisionDataModuleWithChunks(DummyTorchVisionDataModule): - """This dataset has multiple chunks that can be queried through chunk ID.""" - - def __init__( - self, - transform=None, - val_size=defaults.VALIDATION_SIZE, - seed=defaults.SEED, - chunk_id=0, - num_chunks=1, - ): - super().__init__(transform=transform, val_size=val_size, seed=seed) - self.chunk_id = chunk_id - self.num_chunks = num_chunks - - def setup(self): - super().setup() - train_data = DummyDataset( - torch.split(self.X_train, 100 // self.num_chunks)[self.chunk_id], - torch.split(self.y_train, 100 // self.num_chunks)[self.chunk_id], - self._transform, - ) - self._train_data, self._val_data = self._split_train_val_data(train_data) - self._test_data = [] - for i in range(self.num_chunks): - self._test_data.append( - DummyDataset( - torch.split(self.X_test, 100 // self.num_chunks)[i], - torch.split(self.y_test, 100 // self.num_chunks)[i], - self._transform, - ) - ) - - # TODO: This is a work-around to make CLEAR available as a data module (single test dataset) - # as well as a scenario (all test datasets) via BenchmarkScenario. Clean up! - def test_data(self) -> Dataset: - return self._test_data[self.chunk_id] diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 2ad69e8a..ca0111b2 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -22,14 +22,13 @@ train_transform, ) from renate.benchmark.scenarios import ( - BenchmarkScenario, ClassIncrementalScenario, FeatureSortingScenario, HueShiftScenario, IIDScenario, ImageRotationScenario, PermutationScenario, - WildTimeScenario, + TimeIncrementalScenario, ) from renate.models.prediction_strategies import ICaRLClassificationStrategy @@ -176,7 +175,6 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): ImageRotationScenario, 3, ), - ("BenchmarkScenario", "CLEAR10", {"num_tasks": 5}, BenchmarkScenario, 5), ( "PermutationScenario", "MNIST", @@ -202,18 +200,19 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): HueShiftScenario, 3, ), + ("TimeIncrementalScenario", "CLEAR10", {"num_tasks": 5}, TimeIncrementalScenario, 5), ( - "WildTimeScenario", + "TimeIncrementalScenario", "arxiv", {"num_tasks": 3, "pretrained_model_name": "distilbert-base-uncased"}, - WildTimeScenario, + TimeIncrementalScenario, 3, ), ( - "WildTimeScenario", + "TimeIncrementalScenario", "fmow", {}, - WildTimeScenario, + TimeIncrementalScenario, 16, ), ), @@ -221,10 +220,10 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "class_incremental", "iid", "rotation", - "benchmark", "permutation", "feature_sorting", "hue_shift", + "time_with_clear", "wild_time_text_with_tokenizer", "wild_time_image_all_tasks", ], @@ -256,30 +255,34 @@ def test_data_module_fn( assert scenario._randomness == scenario_kwargs["randomness"] elif expected_scenario_class == HueShiftScenario: assert scenario._randomness == scenario_kwargs["randomness"] - elif expected_scenario_class == WildTimeScenario: + elif 
expected_scenario_class == TimeIncrementalScenario: if "pretrained_model_name" in scenario_kwargs: assert scenario._data_module._tokenizer is not None - else: + elif dataset_name not in ["CLEAR10", "CLEAR100"]: assert scenario._data_module._tokenizer is None assert scenario._num_tasks == expected_num_tasks @pytest.mark.parametrize( - "dataset_name,use_transforms", + "dataset_name,use_transforms,test_compose", ( - ("MNIST", False), - ("FashionMNIST", False), - ("CIFAR10", True), - ("CIFAR100", True), - ("hfd-rotten_tomatoes", False), + ("MNIST", False, False), + ("FashionMNIST", False, False), + ("CIFAR10", True, False), + ("CIFAR100", True, False), + ("CLEAR10", True, True), + ("hfd-rotten_tomatoes", False, False), ), ) -def test_transforms(dataset_name, use_transforms): +def test_transforms(dataset_name, use_transforms, test_compose): train_preprocessing = train_transform(dataset_name) test_preprocessing = experiment_config.test_transform(dataset_name) if use_transforms: assert isinstance(train_preprocessing, Compose) - assert isinstance(test_preprocessing, Normalize) + if test_compose: + assert isinstance(test_preprocessing, Compose) + else: + assert isinstance(test_preprocessing, Normalize) else: assert train_preprocessing is None assert test_preprocessing is None diff --git a/test/renate/benchmark/test_scenarios.py b/test/renate/benchmark/test_scenarios.py index aeb2d00f..a53035cc 100644 --- a/test/renate/benchmark/test_scenarios.py +++ b/test/renate/benchmark/test_scenarios.py @@ -7,15 +7,15 @@ import torch from torchvision.transforms.functional import rotate -from dummy_datasets import DummyTorchVisionDataModule, DummyTorchVisionDataModuleWithChunks +from dummy_datasets import DummyTorchVisionDataModule from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule from renate.benchmark.scenarios import ( - BenchmarkScenario, ClassIncrementalScenario, FeatureSortingScenario, IIDScenario, ImageRotationScenario, PermutationScenario, + TimeIncrementalScenario, ) from renate.utils.pytorch import randomly_split_data @@ -71,7 +71,7 @@ def test_class_incremental_scenario_class_grouping_error(): """Classes selected do not exist in data.""" scenario = ClassIncrementalScenario( data_module=DummyTorchVisionDataModule(val_size=0.3, seed=42), - class_groupings=[[0, 1, 3], [2, 200]], + class_groupings=((0, 1, 3), (2, 200)), chunk_id=0, ) scenario.prepare_data() @@ -79,6 +79,16 @@ def test_class_incremental_scenario_class_grouping_error(): scenario.setup() +def test_time_incremental_scenario_init_error(): + """Check that TimeIncrementalScenario raises Exception for unsupported DataModule.""" + with pytest.raises(ValueError, match=r"This scenario is only compatible with*"): + TimeIncrementalScenario( + data_module=DummyTorchVisionDataModule(), + num_tasks=2, + chunk_id=0, + ) + + def test_image_rotation_scenario(): data_module = DummyTorchVisionDataModule(val_size=0.3) degrees = [15, 75] @@ -167,17 +177,6 @@ def test_transforms_in_transform_scenarios_are_distinct(scenario_class, scenario assert not torch.all(torch.isclose(transform(x), transform2(x))) -def test_benchmark_scenario(): - data_module = DummyTorchVisionDataModuleWithChunks(num_chunks=3, val_size=0.2) - for chunk_id in range(3): - scenario = BenchmarkScenario(data_module=data_module, num_tasks=3, chunk_id=chunk_id) - scenario.prepare_data() - scenario.setup() - assert scenario.train_data() is not None - assert scenario.val_data() is not None - assert len(scenario.test_data()) == 3 - - def test_iid_scenario(): """Tests that 
the IID scenario creates a non-overlapping split.""" data_module = DummyTorchVisionDataModule(val_size=0.3) From 98dac269d8985a9e2d97cbf4467a8d8c1c22dc55 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Jul 2023 16:33:54 +0200 Subject: [PATCH 49/89] Fix CLS-ER Loss (#347) --- src/renate/updaters/learner_components/losses.py | 2 +- test/integration_tests/configs/suites/quick/cls-er.json | 4 ++-- test/integration_tests/configs/suites/quick/super-er.json | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/renate/updaters/learner_components/losses.py b/src/renate/updaters/learner_components/losses.py index 05dcc8b4..115d4c66 100644 --- a/src/renate/updaters/learner_components/losses.py +++ b/src/renate/updaters/learner_components/losses.py @@ -365,7 +365,7 @@ def _loss( (inputs_memory, targets_memory), _ = batch_memory with torch.no_grad(): outputs_plastic = self._plastic_model(inputs_memory) - outputs_stable = self._plastic_model(inputs_memory) + outputs_stable = self._stable_model(inputs_memory) probs_plastic = F.softmax(outputs_plastic, dim=-1) probs_stable = F.softmax(outputs_stable, dim=-1) label_mask = F.one_hot(targets_memory, num_classes=outputs_stable.shape[-1]) > 0 diff --git a/test/integration_tests/configs/suites/quick/cls-er.json b/test/integration_tests/configs/suites/quick/cls-er.json index f7882609..13498c6c 100644 --- a/test/integration_tests/configs/suites/quick/cls-er.json +++ b/test/integration_tests/configs/suites/quick/cls-er.json @@ -5,6 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "class-incremental-mlp-cls-er", - "expected_accuracy_linux": [[0.9858155846595764, 0.9740450382232666]], - "expected_accuracy_darwin": [[0.9858155846595764, 0.9755141735076904]] + "expected_accuracy_linux": [[0.9839243292808533, 0.9740450382232666]], + "expected_accuracy_darwin": [[0.9834515452384949, 0.9740450382232666]] } diff --git a/test/integration_tests/configs/suites/quick/super-er.json b/test/integration_tests/configs/suites/quick/super-er.json index 4fe2a4e0..3831303d 100644 --- a/test/integration_tests/configs/suites/quick/super-er.json +++ b/test/integration_tests/configs/suites/quick/super-er.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-super-er", - "expected_accuracy_linux": [[0.9330000281333923, 0.9144999980926514], [0.9434999823570251, 0.8974999785423279]], - "expected_accuracy_darwin": [[0.9390000104904175, 0.9089999794960022]] + "expected_accuracy_linux": [[0.9390000104904175, 0.9120000004768372], [0.9384999871253967, 0.9129999876022339]], + "expected_accuracy_darwin": [[0.9399999976158142, 0.9129999876022339]] } From 38d914376d9e8908e005c3f83137906b41cd7735 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 24 Jul 2023 16:56:54 +0200 Subject: [PATCH 50/89] Enable Offline-ER for NestedTensors (#336) --- .../updaters/experimental/offline_er.py | 18 ++++---- src/renate/utils/pytorch.py | 39 ++++++++++++++++- test/renate/utils/test_pytorch.py | 43 ++++++++++++++++++- 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 13c893a1..1baea452 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -17,7 +17,11 @@ from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater -from renate.utils.pytorch import 
move_tensors_to_device +from renate.utils.pytorch import ( + cat_nested_tensors, + get_length_nested_tensors, + move_tensors_to_device, +) class OfflineExperienceReplayLearner(ReplayLearner): @@ -97,9 +101,7 @@ def on_model_update_end(self) -> None: self._num_points_previous_tasks += self._num_points_current_task self._num_points_current_task = -1 - def training_step( - self, batch: Dict[str, Tuple[NestedTensors, torch.Tensor]], batch_idx: int - ) -> STEP_OUTPUT: + def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT: """PyTorch Lightning function to return the training loss.""" if self._loss_weight_new_data is None: alpha = self._num_points_current_task / ( @@ -108,13 +110,13 @@ def training_step( else: alpha = self._loss_weight_new_data inputs, targets = batch["current_task"] - device = inputs.device - batch_size_current = inputs.shape[0] + device = next(self.parameters()).device + batch_size_current = get_length_nested_tensors(inputs) batch_size_mem = 0 if "memory" in batch: (inputs_mem, targets_mem), _ = batch["memory"] - batch_size_mem = inputs_mem.shape[0] - inputs = torch.cat((inputs, inputs_mem), 0) + batch_size_mem = get_length_nested_tensors(inputs_mem) + inputs = cat_nested_tensors((inputs, inputs_mem), 0) targets = torch.cat((targets, targets_mem), 0) outputs = self(inputs) loss = self._loss_fn(outputs, targets) diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 970a08fe..8f78e037 100644 --- a/src/renate/utils/pytorch.py +++ b/src/renate/utils/pytorch.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import math -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch from torch.utils.data import Dataset, random_split @@ -88,3 +88,40 @@ def move_tensors_to_device(tensors: NestedTensors, device: torch.device) -> Nest "Expected `tensors` to be a nested structure of tensors, tuples, list and dict; " f"discovered {type(tensors)}." ) + + +def get_length_nested_tensors(batch: NestedTensors) -> torch.Size: + """Given a NestedTensor, return its length. + + Assumes that the first axis in each element is the same. + """ + if isinstance(batch, torch.Tensor): + return batch.shape[0] + if isinstance(batch, tuple): + return batch[0].shape[0] + if isinstance(batch, dict): + return batch[next(iter(batch.keys()))].shape[0] + + +def cat_nested_tensors( + nested_tensors: Union[Tuple[NestedTensors], List[NestedTensors]], axis: int = 0 +) -> NestedTensors: + """Concatenates the two NestedTensors. + + Equivalent of PyTorch's ``cat`` function for ``NestedTensors``. + + Args: + nested_tensors: Tensors to be concatenated. + axis: Concatenation axis. 
+ """ + if isinstance(nested_tensors[0], torch.Tensor): + return torch.cat(nested_tensors, axis) + if isinstance(nested_tensors[0], tuple): + return tuple( + cat_nested_tensors(nested_tensor, axis) for nested_tensor in zip(*nested_tensors) + ) + if isinstance(nested_tensors[0], dict): + return { + key: cat_nested_tensors([nested_tensor[key] for nested_tensor in nested_tensors], axis) + for key in nested_tensors[0] + } diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py index 6ae2fdde..b247976a 100644 --- a/test/renate/utils/test_pytorch.py +++ b/test/renate/utils/test_pytorch.py @@ -6,7 +6,7 @@ from torch.utils.data import TensorDataset from renate.utils import pytorch -from renate.utils.pytorch import randomly_split_data +from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors, randomly_split_data @pytest.mark.parametrize("model", [torchvision.models.resnet18(pretrained=True)]) @@ -61,3 +61,44 @@ def test_random_splitting_sample_split_with_same_random_seed(): for i in range(5): assert torch.equal(d_1_split_1[i][0], d_2_split_1[i][0]) assert torch.equal(d_1_split_2[i][0], d_2_split_2[i][0]) + + +def test_get_length_nested_tensors(): + expected_batch_size = 10 + t = torch.zeros(expected_batch_size) + assert get_length_nested_tensors(t) == expected_batch_size + tuple_tensor = (t, t) + assert get_length_nested_tensors(tuple_tensor) == expected_batch_size + dict_tensor = {"k1": t, "k2": t} + assert get_length_nested_tensors(dict_tensor) == expected_batch_size + + +def test_cat_nested_tensors(): + tensor_dim = 2 + first_dim_ones = 8 + zeros = torch.zeros((2, tensor_dim)) + ones = torch.ones((first_dim_ones, tensor_dim)) + result = cat_nested_tensors((zeros, ones)) + assert get_length_nested_tensors(result) == 10 + assert result.mean() == 0.8 + tuple_tensor = (zeros, ones) + result = cat_nested_tensors((tuple_tensor, tuple_tensor)) + assert get_length_nested_tensors(result) == 4 + assert result[0].sum() == 0 + assert result[1].sum() == 2 * first_dim_ones * tensor_dim + dict_tensor = {"zeros": zeros, "ones": ones} + result = cat_nested_tensors((dict_tensor, dict_tensor)) + assert get_length_nested_tensors(result) == 4 + assert result["zeros"].sum() == 0 + assert result["ones"].sum() == 2 * first_dim_ones * tensor_dim + + +def test_cat_nested_tensors_wrong_shape(): + tensor1 = torch.zeros((2, 2)) + tensor2 = torch.zeros((2, 3)) + with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"): + cat_nested_tensors((tensor1, tensor2)) + with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"): + cat_nested_tensors(((tensor1, tensor1), (tensor1, tensor2))) + with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"): + cat_nested_tensors(({"k1": tensor1, "k2": tensor1}, {"k1": tensor1, "k2": tensor2})) From 740cdf69cd4d4de14c55ddda2ae13ca907df7728 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Wed, 26 Jul 2023 17:03:14 +0200 Subject: [PATCH 51/89] Make number of epochs "finetuning-equivalent" (#344) --- src/renate/cli/parsing_functions.py | 3 ++- src/renate/cli/run_experiment_with_scenario.py | 2 +- src/renate/training/training.py | 4 +++- src/renate/updaters/model_updater.py | 8 +++++++- test/integration_tests/configs/suites/quick/joint.json | 4 ++-- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 7b713c51..427c0bc5 100644 --- 
a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -230,7 +230,8 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: "max_epochs": { "type": int, "default": defaults.MAX_EPOCHS, - "help": f"Number of epochs trained at most. Default: {defaults.MAX_EPOCHS}", + "help": "Maximum number of (finetuning-equivalent) epochs. " + f"Default: {defaults.MAX_EPOCHS}", "argument_group": OPTIONAL_ARGS_GROUP, }, "task_id": { diff --git a/src/renate/cli/run_experiment_with_scenario.py b/src/renate/cli/run_experiment_with_scenario.py index f6b0a3bf..2e35439e 100644 --- a/src/renate/cli/run_experiment_with_scenario.py +++ b/src/renate/cli/run_experiment_with_scenario.py @@ -100,7 +100,7 @@ def run(self): "--max_epochs", type=int, default=defaults.MAX_EPOCHS, - help=f"Number of epochs trained at most. Default: {defaults.MAX_EPOCHS}", + help=f"Maximum number of (finetuning-equiv.) epochs. Default: {defaults.MAX_EPOCHS}", ) argument_group.add_argument( "--seed", diff --git a/src/renate/training/training.py b/src/renate/training/training.py index f2d876b9..b528fba7 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -105,7 +105,9 @@ def run_training_job( metric: Name of metric to optimize. backend: Whether to run jobs locally (`local`) or on SageMaker (`sagemaker`). updater: Updater used for model update. - max_epochs: Maximum number of epochs the model is trained. + max_epochs: The maximum number of epochs used to train the model. For comparability between + methods, epochs are interpreted as "finetuning-equivalent". That is, one epoch is + defined as `len(current_task_dataset) / batch_size` update steps. task_id: Unique identifier for the current task. chunk_id: Unique identifier for the current data chunk. input_state_url: Path to the Renate model state. diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index 451440b0..8f4db296 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -213,7 +213,9 @@ class ModelUpdater(abc.ABC): state available) or replace current arguments of the learner. input_state_folder: Folder used by Renate to store files for current state. output_state_folder: Folder used by Renate to store files for next state. - max_epochs: The maximum number of epochs used to train the model. + max_epochs: The maximum number of epochs used to train the model. For comparability between + methods, epochs are interpreted as "finetuning-equivalent". That is, one epoch is + defined as `len(current_task_dataset) / batch_size` update steps. train_transform: The transformation applied during training. train_target_transform: The target transformation applied during testing. test_transform: The transformation at test time. @@ -408,10 +410,14 @@ def _fit_learner( ) strategy = create_strategy(self._devices, self._strategy) + # Finetuning-equivalent epochs. 
+ num_batches = len(learner._train_dataset) // learner._batch_size + num_batches += min(len(learner._train_dataset) % learner._batch_size, 1) trainer = Trainer( accelerator=self._accelerator, devices=self._devices, max_epochs=self._max_epochs, + limit_train_batches=num_batches, callbacks=callbacks, logger=self._logger, enable_progress_bar=False, diff --git a/test/integration_tests/configs/suites/quick/joint.json b/test/integration_tests/configs/suites/quick/joint.json index 8cb887cd..2de798c4 100644 --- a/test/integration_tests/configs/suites/quick/joint.json +++ b/test/integration_tests/configs/suites/quick/joint.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "iid-mlp-joint", - "expected_accuracy_linux": [[0.8639000058174133, 0.8639000058174133], [0.8618000149726868, 0.8618000149726868]], - "expected_accuracy_darwin": [[0.859499990940094, 0.859499990940094]] + "expected_accuracy_linux": [[0.8495000004768372, 0.8495000004768372], [0.8427000045776367, 0.8427000045776367]], + "expected_accuracy_darwin": [[0.84170001745224, 0.84170001745224]] } From eacf1813ba1983b432269e63c31ec7752ec041f8 Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Thu, 27 Jul 2023 10:30:33 +0200 Subject: [PATCH 52/89] Fix loss weighting in OfflineER (#355) --- .../updaters/experimental/offline_er.py | 29 ++++++------------- .../configs/suites/quick/offline-er.json | 4 +-- 2 files changed, 11 insertions(+), 22 deletions(-) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 1baea452..2666bd7c 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -17,11 +17,7 @@ from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater -from renate.utils.pytorch import ( - cat_nested_tensors, - get_length_nested_tensors, - move_tensors_to_device, -) +from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors class OfflineExperienceReplayLearner(ReplayLearner): @@ -109,31 +105,24 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) ) else: alpha = self._loss_weight_new_data + alpha = torch.tensor(alpha, device=next(self.parameters()).device) inputs, targets = batch["current_task"] - device = next(self.parameters()).device batch_size_current = get_length_nested_tensors(inputs) - batch_size_mem = 0 if "memory" in batch: (inputs_mem, targets_mem), _ = batch["memory"] - batch_size_mem = get_length_nested_tensors(inputs_mem) inputs = cat_nested_tensors((inputs, inputs_mem), 0) targets = torch.cat((targets, targets_mem), 0) outputs = self(inputs) loss = self._loss_fn(outputs, targets) if "memory" in batch: - weights = torch.Tensor( - [ - [alpha for _ in range(batch_size_current)] - + [(1 - alpha) for _ in range(batch_size_mem)] - ] - ) - self._loss_collections["train_losses"]["memory_loss"](loss[batch_size_current:].mean()) - self._loss_collections["train_losses"]["base_loss"](loss[:batch_size_current].mean()) - weights = move_tensors_to_device(weights, device=device) - loss = weights / weights.mean() * loss + loss_current = loss[:batch_size_current].mean() + loss_memory = loss[batch_size_current:].mean() + self._loss_collections["train_losses"]["base_loss"](loss_current) + self._loss_collections["train_losses"]["memory_loss"](loss_memory) + loss = alpha * loss_current + (1.0 - alpha) * loss_memory else: - 
self._loss_collections["train_losses"]["base_loss"](loss[:batch_size_current].mean()) - loss = loss.mean() + loss = loss.mean() + self._loss_collections["train_losses"]["base_loss"](loss) self._update_metrics(outputs, targets, "train") return {"loss": loss} diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index c0b04935..2c62856a 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.765500009059906, 0.4020000100135803], [0.7574999928474426, 0.4424999952316284]], - "expected_accuracy_darwin": [[0.7549999952316284, 0.45249998569488525]] + "expected_accuracy_linux": [[0.7319999933242798, 0.4699999988079071], [0.7515000104904175, 0.49300000071525574]], + "expected_accuracy_darwin": [[0.7300000190734863, 0.5350000262260437]] } From 2e071d8fa513d30e2d6d121da6f2e80a27b16680 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 27 Jul 2023 15:46:20 +0200 Subject: [PATCH 53/89] Fixing Bug with HPO (#345) --- src/renate/updaters/experimental/er.py | 142 ++++---- .../updaters/learner_components/component.py | 34 +- .../updaters/learner_components/losses.py | 188 ++-------- .../learner_components/reinitialization.py | 25 +- test/conftest.py | 8 - .../configs/suites/quick/cls-er.json | 2 +- .../configs/suites/quick/super-er.json | 4 +- test/renate/updaters/experimental/test_er.py | 323 ++++++++++++------ 8 files changed, 341 insertions(+), 385 deletions(-) diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index c2427b83..a0d4e331 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import torch -import torch.nn as nn import torchmetrics from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -19,6 +18,7 @@ from renate.models import RenateModule from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner +from renate.updaters.learner_components.component import Component from renate.updaters.learner_components.losses import ( WeightedCLSLossComponent, WeightedCustomLossComponent, @@ -52,7 +52,7 @@ class BaseExperienceReplayLearner(ReplayLearner, abc.ABC): def __init__( self, - components: nn.ModuleDict, + components: Dict[str, Component], loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -215,12 +215,24 @@ def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) - component.on_train_batch_end(model=self._model) @abc.abstractmethod - def components(self, **kwargs) -> nn.ModuleDict: + def components(self, **kwargs) -> Dict[str, Component]: """Returns the components of the learner. This is a user-defined function that should return a dictionary of components. 
""" + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Load states of components.""" + super().on_load_checkpoint(checkpoint) + for component in self._components.values(): + component.on_load_checkpoint(checkpoint) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Save states of components.""" + super().on_save_checkpoint(checkpoint) + for component in self._components.values(): + component.on_save_checkpoint(checkpoint) + class ExperienceReplayLearner(BaseExperienceReplayLearner): """This is the version of experience replay proposed in @@ -238,14 +250,12 @@ def __init__(self, alpha: float = defaults.ER_ALPHA, **kwargs) -> None: def components( self, loss_fn: Optional[torch.nn.Module] = None, alpha: float = defaults.ER_ALPHA - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - "memory_loss": WeightedCustomLossComponent( - loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True - ) - } - ) + ) -> Dict[str, Component]: + return { + "memory_loss": WeightedCustomLossComponent( + loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True + ) + } class DarkExperienceReplayLearner(ExperienceReplayLearner): @@ -272,7 +282,7 @@ def components( loss_fn: Optional[torch.nn.Module] = None, alpha: float = defaults.DER_ALPHA, beta: float = defaults.DER_BETA, - ) -> nn.ModuleDict: + ) -> Dict[str, Component]: components = super().components(loss_fn=loss_fn, alpha=beta) components.update( { @@ -316,17 +326,15 @@ def components( alpha: float = defaults.POD_ALPHA, distillation_type: str = defaults.POD_DISTILLATION_TYPE, normalize: bool = defaults.POD_NORMALIZE, - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - "pod_loss": WeightedPooledOutputDistillationLossComponent( - weight=alpha, - sample_new_memory_batch=True, - distillation_type=distillation_type, - normalize=normalize, - ) - } - ) + ) -> Dict[str, Component]: + return { + "pod_loss": WeightedPooledOutputDistillationLossComponent( + weight=alpha, + sample_new_memory_batch=True, + distillation_type=distillation_type, + normalize=normalize, + ) + } class CLSExperienceReplayLearner(BaseExperienceReplayLearner): @@ -381,23 +389,21 @@ def components( stable_model_update_weight: float = defaults.CLS_STABLE_MODEL_UPDATE_WEIGHT, plastic_model_update_probability: float = defaults.CLS_PLASTIC_MODEL_UPDATE_PROBABILITY, stable_model_update_probability: float = defaults.CLS_STABLE_MODEL_UPDATE_PROBABILITY, - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - "memory_loss": WeightedCustomLossComponent( - loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True - ), - "cls_loss": WeightedCLSLossComponent( - weight=beta, - sample_new_memory_batch=False, - model=model, - plastic_model_update_weight=plastic_model_update_weight, - stable_model_update_weight=stable_model_update_weight, - plastic_model_update_probability=plastic_model_update_probability, - stable_model_update_probability=stable_model_update_probability, - ), - } - ) + ) -> Dict[str, Component]: + return { + "memory_loss": WeightedCustomLossComponent( + loss_fn=loss_fn, weight=alpha, sample_new_memory_batch=True + ), + "cls_loss": WeightedCLSLossComponent( + weight=beta, + sample_new_memory_batch=False, + model=model, + plastic_model_update_weight=plastic_model_update_weight, + stable_model_update_weight=stable_model_update_weight, + plastic_model_update_probability=plastic_model_update_probability, + stable_model_update_probability=stable_model_update_probability, + ), + } class SuperExperienceReplayLearner(BaseExperienceReplayLearner): @@ -482,35 
+488,33 @@ def components( pod_alpha: float = defaults.SER_POD_ALPHA, pod_distillation_type: str = defaults.SER_POD_DISTILLATION_TYPE, pod_normalize: bool = defaults.SER_POD_NORMALIZE, - ) -> nn.ModuleDict: - return nn.ModuleDict( - { - "mse_loss": WeightedMeanSquaredErrorLossComponent( - weight=der_alpha, sample_new_memory_batch=True - ), - "memory_loss": WeightedCustomLossComponent( - loss_fn=loss_fn, weight=der_beta, sample_new_memory_batch=True - ), - "cls_loss": WeightedCLSLossComponent( - weight=cls_alpha, - sample_new_memory_batch=False, - model=model, - stable_model_update_weight=cls_stable_model_update_weight, - plastic_model_update_weight=cls_plastic_model_update_weight, - stable_model_update_probability=cls_stable_model_update_probability, - plastic_model_update_probability=cls_plastic_model_update_probability, - ), - "shrink_perturb": ShrinkAndPerturbReinitializationComponent( - shrink_factor=sp_shrink_factor, sigma=sp_sigma - ), - "pod_loss": WeightedPooledOutputDistillationLossComponent( - weight=pod_alpha, - sample_new_memory_batch=True, - distillation_type=pod_distillation_type, - normalize=pod_normalize, - ), - } - ) + ) -> Dict[str, Component]: + return { + "mse_loss": WeightedMeanSquaredErrorLossComponent( + weight=der_alpha, sample_new_memory_batch=True + ), + "memory_loss": WeightedCustomLossComponent( + loss_fn=loss_fn, weight=der_beta, sample_new_memory_batch=True + ), + "cls_loss": WeightedCLSLossComponent( + weight=cls_alpha, + sample_new_memory_batch=False, + model=model, + stable_model_update_weight=cls_stable_model_update_weight, + plastic_model_update_weight=cls_plastic_model_update_weight, + stable_model_update_probability=cls_stable_model_update_probability, + plastic_model_update_probability=cls_plastic_model_update_probability, + ), + "shrink_perturb": ShrinkAndPerturbReinitializationComponent( + shrink_factor=sp_shrink_factor, sigma=sp_sigma + ), + "pod_loss": WeightedPooledOutputDistillationLossComponent( + weight=pod_alpha, + sample_new_memory_batch=True, + distillation_type=pod_distillation_type, + normalize=pod_normalize, + ), + } class ExperienceReplayModelUpdater(SingleTrainingLoopUpdater): diff --git a/src/renate/updaters/learner_components/component.py b/src/renate/updaters/learner_components/component.py index a7a173be..e6b9df89 100644 --- a/src/renate/updaters/learner_components/component.py +++ b/src/renate/updaters/learner_components/component.py @@ -4,24 +4,28 @@ from typing import Any, Dict, List, Optional, Tuple import torch -import torch.nn as nn from renate.models import RenateModule from renate.types import NestedTensors -class Component(nn.Module, abc.ABC): +class Component(abc.ABC): """The abstract class implementing a Component, usable in the BaseExperienceReplayLearner. This is an abstract class from which each other component e.g. additional regularising loss or a module updater should inherit from. - The components should be a modular and independent to an extend where they can be composed + The components should be a modular and independent to an extent where they can be composed together in an ordered list to be deployed in the BaseExperienceReplayLearner. + + Args: + weight: A scaling coefficient which should scale the loss which gets returned. + sample_new_memory_batch: Whether a new batch of data should be sampled from the memory + buffer when the loss is calculated. 
""" - def __init__(self, **kwargs: Any) -> None: - super().__init__() - self._register_parameters(**kwargs) + def __init__(self, weight: float = 0, sample_new_memory_batch: bool = False) -> None: + self.weight = weight + self.sample_new_memory_batch = sample_new_memory_batch self._verify_attributes() def loss( @@ -57,20 +61,14 @@ def on_train_batch_end(self, model: RenateModule) -> None: """ pass - @property - def weight(self) -> torch.Tensor: - """The weight of the loss component.""" - return self._weight - - @property - def sample_new_memory_batch(self) -> torch.Tensor: - """Whether to sample a new memory batch or not.""" - return self._sample_new_memory_batch - def _verify_attributes(self) -> None: """Verify if attributes have valid values.""" pass - def _register_parameters(self) -> None: - """Function to register parameters of the component.""" + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Load relevant information from checkpoint.""" + pass + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Add relevant information to checkpoint.""" pass diff --git a/src/renate/updaters/learner_components/losses.py b/src/renate/updaters/learner_components/losses.py index 115d4c66..1552e41e 100644 --- a/src/renate/updaters/learner_components/losses.py +++ b/src/renate/updaters/learner_components/losses.py @@ -16,34 +16,12 @@ class WeightedLossComponent(Component, ABC): """The abstract class implementing a weighted loss function. This is an abstract class from which each other loss should inherit from. - - Args: - weight: A scaling coefficient which should scale the loss which gets returned. - sample_new_memory_batch: Whether a new batch of data should be sampled from the memory - buffer when the loss is calculated. """ - def __init__(self, weight: float, sample_new_memory_batch: bool, **kwargs: Any) -> None: - super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch, **kwargs) - def _verify_attributes(self) -> None: """Verify if attributes have valid values.""" super()._verify_attributes() - assert self._weight >= 0, "Weight must be larger than 0." - - def _register_parameters(self, weight: float, sample_new_memory_batch: bool) -> None: - """Register parameters of the loss.""" - super()._register_parameters() - self.register_buffer("_weight", torch.tensor(weight, dtype=torch.float)) - self.register_buffer( - "_sample_new_memory_batch", torch.tensor(sample_new_memory_batch, dtype=torch.bool) - ) - - def set_weight(self, weight: float) -> None: - self._weight.data = torch.tensor( - weight, dtype=self._weight.dtype, device=self._weight.device - ) - self._verify_attributes() + assert self.weight >= 0, "Weight must be larger than 0." def loss( self, @@ -79,10 +57,8 @@ class WeightedCustomLossComponent(WeightedLossComponent): buffer when the loss is calculated. 
""" - def __init__( - self, loss_fn: Callable, weight: float, sample_new_memory_batch: bool, **kwargs: Any - ) -> None: - super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch, **kwargs) + def __init__(self, loss_fn: Callable, weight: float, sample_new_memory_batch: bool) -> None: + super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch) self._loss_fn = loss_fn def _loss( @@ -145,41 +121,8 @@ def __init__( normalize: bool = True, ) -> None: self._distillation_type = distillation_type - super().__init__( - weight=weight, - sample_new_memory_batch=sample_new_memory_batch, - normalize=normalize, - ) - - def _register_parameters( - self, weight: float, sample_new_memory_batch: bool, normalize: bool - ) -> None: - """Register parameters of the loss.""" - super()._register_parameters(weight=weight, sample_new_memory_batch=sample_new_memory_batch) - self.register_buffer("_normalize", torch.tensor(normalize, dtype=torch.bool)) - - def _save_to_state_dict( - self, destination: Dict[str, Any], prefix: str, keep_vars: bool - ) -> None: - """Save attributes to state dict.""" - super()._save_to_state_dict(destination, prefix, keep_vars) - destination[prefix + "distillation_type"] = self._distillation_type - - def _load_from_state_dict( - self, - state_dict: Dict[str, Any], - prefix: str, - local_metadata: Dict[str, Any], - strict: bool, - missing_keys: List[str], - unexpected_keys: List[str], - error_msgs: List[str], - ) -> None: - """Load attributes from state dict.""" - self._distillation_type = state_dict.pop(prefix + "distillation_type") - super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ) + super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch) + self._normalize = normalize def _verify_attributes(self) -> None: """Verify if attributes have valid values.""" @@ -208,7 +151,7 @@ def _pod(self, features: torch.Tensor, features_memory: torch.Tensor) -> torch.T features = features.pow(2) features_memory = features_memory.pow(2) - if self._distillation_type == "channels": + if self._distillation_type == "channel": features, features_memory = self._sum_reshape(features, 1), self._sum_reshape( features_memory, 1 ) @@ -239,14 +182,6 @@ def _pod(self, features: torch.Tensor, features_memory: torch.Tensor) -> torch.T return torch.frobenius_norm(features - features_memory, dim=-1).mean(dim=0) - def set_distillation_type(self, distillation_type: str) -> None: - self._distillation_type = distillation_type - self._verify_attributes() - - def set_normalize(self, normalize: bool) -> None: - self._normalize = torch.tensor(normalize, dtype=torch.bool, device=self._normalize.device) - self._verify_attributes() - def _loss( self, outputs_memory: torch.Tensor, @@ -256,7 +191,7 @@ def _loss( """Compute the pooled output with respect to current and cached intermediate outputs from memory. 
""" - loss = 0.0 + loss = torch.tensor(0.0) _, meta_data = batch_memory for n in range(len(intermediate_representation_memory)): features = intermediate_representation_memory[n] @@ -298,53 +233,18 @@ def __init__( stable_model_update_probability: float, plastic_model_update_probability: float, ) -> None: - super().__init__( - weight=weight, - sample_new_memory_batch=sample_new_memory_batch, - stable_model_update_weight=stable_model_update_weight, - plastic_model_update_weight=plastic_model_update_weight, - stable_model_update_probability=stable_model_update_probability, - plastic_model_update_probability=plastic_model_update_probability, - iteration=0, - ) + self._stable_model_update_weight = stable_model_update_weight + self._stable_model_update_weight = stable_model_update_weight + self._plastic_model_update_weight = plastic_model_update_weight + self._stable_model_update_probability = stable_model_update_probability + self._plastic_model_update_probability = plastic_model_update_probability + self._iteration = 0 + super().__init__(weight=weight, sample_new_memory_batch=sample_new_memory_batch) self._plastic_model: RenateModule = copy.deepcopy(model) self._stable_model: RenateModule = copy.deepcopy(model) self._plastic_model.deregister_hooks() self._stable_model.deregister_hooks() - def _register_parameters( - self, - weight: float, - sample_new_memory_batch: bool, - stable_model_update_weight: float, - plastic_model_update_weight: float, - stable_model_update_probability: float, - plastic_model_update_probability: float, - iteration: int, - ) -> None: - """Register the parameters of the loss component.""" - super()._register_parameters( - weight=weight, - sample_new_memory_batch=sample_new_memory_batch, - ) - self.register_buffer( - "_stable_model_update_weight", - torch.tensor(stable_model_update_weight, dtype=torch.float32), - ) - self.register_buffer( - "_plastic_model_update_weight", - torch.tensor(plastic_model_update_weight, dtype=torch.float32), - ) - self.register_buffer( - "_stable_model_update_probability", - torch.tensor(stable_model_update_probability, dtype=torch.float32), - ) - self.register_buffer( - "_plastic_model_update_probability", - torch.tensor(plastic_model_update_probability, dtype=torch.float32), - ) - self.register_buffer("_iteration", torch.tensor(iteration, dtype=torch.int64)) - def _verify_attributes(self) -> None: """Verify if attributes have valid values.""" super()._verify_attributes() @@ -376,7 +276,7 @@ def _loss( @torch.no_grad() def _update_model_variables( - self, model: RenateModule, original_model: RenateModule, weight: torch.Tensor + self, model: RenateModule, original_model: RenateModule, weight: float ) -> None: """Performs exponential moving average on the stored model copies. @@ -384,9 +284,7 @@ def _update_model_variables( model: Whether the plastic or the stable model is updated. weight: The minimum weight used in the exponential moving average to update the model. 
""" - alpha = min( - 1.0 - torch.tensor(1.0, device=self._iteration.device) / (self._iteration + 1), weight - ) + alpha = min(1.0 - 1.0 / (self._iteration + 1), weight) for ema_p, p in zip(model.parameters(), original_model.parameters()): ema_p.data.mul_(alpha).add_(p.data, alpha=1 - alpha) @@ -394,50 +292,26 @@ def on_train_batch_end(self, model: RenateModule) -> None: """Updates the model copies with the current weights, given the specified probabilities of update, and increments iteration counter.""" self._iteration += 1 - if ( - torch.rand(1, device=self._plastic_model_update_probability.device) - < self._plastic_model_update_probability - ): + if torch.rand(1) < self._plastic_model_update_probability: self._update_model_variables( self._plastic_model, model, self._plastic_model_update_weight ) - if ( - torch.rand(1, device=self._stable_model_update_probability.device) - < self._stable_model_update_probability - ): + if torch.rand(1) < self._stable_model_update_probability: self._update_model_variables( self._stable_model, model, self._stable_model_update_weight ) - def set_stable_model_update_weight(self, stable_model_update_weight: float) -> None: - self._stable_model_update_weight.data = torch.tensor( - stable_model_update_weight, - dtype=self._stable_model_update_weight.dtype, - device=self._stable_model_update_weight.device, - ) - self._verify_attributes() - - def set_plastic_model_update_weight(self, plastic_model_update_weight: float) -> None: - self._plastic_model_update_weight.data = torch.tensor( - plastic_model_update_weight, - dtype=self._plastic_model_update_weight.dtype, - device=self._plastic_model_update_weight.device, - ) - self._verify_attributes() - - def set_stable_model_update_probability(self, stable_model_update_probability: float) -> None: - self._stable_model_update_probability.data = torch.tensor( - stable_model_update_probability, - dtype=self._stable_model_update_probability.dtype, - device=self._stable_model_update_probability.device, - ) - self._verify_attributes() - - def set_plastic_model_update_probability(self, plastic_model_update_probability: float) -> None: - self._plastic_model_update_probability.data = torch.tensor( - plastic_model_update_probability, - dtype=self._plastic_model_update_probability.dtype, - device=self._plastic_model_update_probability.device, - ) - self._verify_attributes() + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Load relevant information from checkpoint.""" + super().on_load_checkpoint(checkpoint) + self._plastic_model.load_state_dict(checkpoint["component-cls-plastic-model"]) + self._stable_model.load_state_dict(checkpoint["component-cls-stable-model"]) + self._iteration = checkpoint["component-cls-iteration"] + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Add plastic and stable model to checkpoint.""" + super().on_save_checkpoint(checkpoint) + checkpoint["component-cls-plastic-model"] = self._plastic_model.state_dict() + checkpoint["component-cls-stable-model"] = self._stable_model.state_dict() + checkpoint["component-cls-iteration"] = self._iteration diff --git a/src/renate/updaters/learner_components/reinitialization.py b/src/renate/updaters/learner_components/reinitialization.py index d9628312..5c052654 100644 --- a/src/renate/updaters/learner_components/reinitialization.py +++ b/src/renate/updaters/learner_components/reinitialization.py @@ -30,18 +30,9 @@ class ShrinkAndPerturbReinitializationComponent(Component): """ def __init__(self, shrink_factor: float, 
sigma: float) -> None: - super().__init__( - shrink_factor=shrink_factor, - sigma=sigma, - ) - - def _register_parameters(self, shrink_factor: float, sigma: float) -> None: - """Register parameters of the loss.""" - super()._register_parameters() - self.register_buffer("_shrink_factor", torch.tensor(shrink_factor, dtype=torch.float)) - self.register_buffer("_sigma", torch.tensor(sigma, dtype=torch.float)) - self.register_buffer("_weight", torch.tensor(0.0, dtype=torch.float)) - self.register_buffer("_sample_new_memory_batch", torch.tensor(False, dtype=torch.bool)) + self._shrink_factor = shrink_factor + self._sigma = sigma + super().__init__() def _verify_attributes(self) -> None: """Verify if attributes have valid values.""" @@ -57,13 +48,3 @@ def on_train_start(self, model: RenateModule) -> None: p.mul_(self._shrink_factor) if self._sigma != 0.0: p.add_(self._sigma * torch.randn(p.size(), device=p.device)) - - def set_shrink_factor(self, shrink_factor: float) -> None: - self._shrink_factor.data = torch.tensor( - shrink_factor, dtype=self._shrink_factor.dtype, device=self._shrink_factor.device - ) - self._verify_attributes() - - def set_sigma(self, sigma: float) -> None: - self._sigma.data = torch.tensor(sigma, dtype=self._sigma.dtype, device=self._sigma.device) - self._verify_attributes() diff --git a/test/conftest.py b/test/conftest.py index 162a64ac..81097979 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -112,14 +112,6 @@ def pytest_collection_modifyitems(config, items): "seed": 1, }, } -LEARNER_HYPERPARAMETER_UPDATES = { - ExperienceReplayLearner: {"batch_size": 128}, - Learner: {"batch_size": 128}, - GDumbLearner: {"batch_size": 128, "memory_size": 50}, - JointLearner: {"batch_size": 128}, - RepeatedDistillationLearner: {"batch_size": 128}, - OfflineExperienceReplayLearner: {"batch_size": 128}, -} AVALANCHE_LEARNER_HYPERPARAMETER_UPDATES = { AvalancheEWCLearner: {"ewc_lambda": 0.3}, AvalancheLwFLearner: {"alpha": 0.2, "temperature": 3}, diff --git a/test/integration_tests/configs/suites/quick/cls-er.json b/test/integration_tests/configs/suites/quick/cls-er.json index 13498c6c..ff3e7ad8 100644 --- a/test/integration_tests/configs/suites/quick/cls-er.json +++ b/test/integration_tests/configs/suites/quick/cls-er.json @@ -5,6 +5,6 @@ "dataset": "mnist.json", "backend": "local", "job_name": "class-incremental-mlp-cls-er", - "expected_accuracy_linux": [[0.9839243292808533, 0.9740450382232666]], + "expected_accuracy_linux": [[0.9834515452384949, 0.9740450382232666]], "expected_accuracy_darwin": [[0.9834515452384949, 0.9740450382232666]] } diff --git a/test/integration_tests/configs/suites/quick/super-er.json b/test/integration_tests/configs/suites/quick/super-er.json index 3831303d..9148ce49 100644 --- a/test/integration_tests/configs/suites/quick/super-er.json +++ b/test/integration_tests/configs/suites/quick/super-er.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-super-er", - "expected_accuracy_linux": [[0.9390000104904175, 0.9120000004768372], [0.9384999871253967, 0.9129999876022339]], - "expected_accuracy_darwin": [[0.9399999976158142, 0.9129999876022339]] + "expected_accuracy_linux": [[0.9394999742507935, 0.9125000238418579], [0.9384999871253967, 0.9135000109672546]], + "expected_accuracy_darwin": [[0.9424999952316284, 0.9100000262260437]] } diff --git a/test/renate/updaters/experimental/test_er.py b/test/renate/updaters/experimental/test_er.py index 1459f0a4..72545b72 100644 --- 
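For reference, the reinitialization step applied above is just a scale-towards-zero followed by additive Gaussian noise. A standalone sketch on a single parameter tensor (toy values, illustrative only):

.. code-block:: python

    import torch

    def shrink_and_perturb_(p: torch.Tensor, shrink_factor: float, sigma: float) -> None:
        # In-place: shrink the parameter, then perturb it with Gaussian noise.
        p.mul_(shrink_factor)
        if sigma != 0.0:
            p.add_(sigma * torch.randn(p.size(), device=p.device))

    torch.manual_seed(0)
    p = torch.ones(3)
    shrink_and_perturb_(p, shrink_factor=0.5, sigma=0.01)
    assert torch.allclose(p, torch.full((3,), 0.5), atol=0.1)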
a/test/renate/updaters/experimental/test_er.py +++ b/test/renate/updaters/experimental/test_er.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import os import pytest import torch @@ -89,125 +88,233 @@ def test_er_validation_buffer(tmpdir): ) +def validate_cls_er(model_updater, learner_kwargs): + assert model_updater._learner._batch_size == learner_kwargs["batch_size"] + assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["alpha"] + assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["beta"] + assert ( + model_updater._learner._components["cls_loss"]._stable_model_update_weight + == learner_kwargs["stable_model_update_weight"] + ) + assert ( + model_updater._learner._components["cls_loss"]._plastic_model_update_weight + == learner_kwargs["plastic_model_update_weight"] + ) + assert ( + model_updater._learner._components["cls_loss"]._stable_model_update_probability + == learner_kwargs["stable_model_update_probability"] + ) + assert ( + model_updater._learner._components["cls_loss"]._plastic_model_update_probability + == learner_kwargs["plastic_model_update_probability"] + ) + + +def validate_dark_er(model_updater, learner_kwargs): + assert model_updater._learner._batch_size == learner_kwargs["batch_size"] + assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["beta"] + assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["alpha"] + + +def validate_pod_er(model_updater, learner_kwargs): + assert model_updater._learner._batch_size == learner_kwargs["batch_size"] + assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + assert model_updater._learner._components["pod_loss"].weight == learner_kwargs["alpha"] + assert ( + model_updater._learner._components["pod_loss"]._distillation_type + == learner_kwargs["distillation_type"] + ) + assert model_updater._learner._components["pod_loss"]._normalize == learner_kwargs["normalize"] + + +def validate_super_er(model_updater, learner_kwargs): + assert model_updater._learner._batch_size == learner_kwargs["batch_size"] + assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["der_beta"] + assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["der_alpha"] + assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["cls_alpha"] + assert ( + model_updater._learner._components["cls_loss"]._stable_model_update_weight + == learner_kwargs["cls_stable_model_update_weight"] + ) + assert ( + model_updater._learner._components["cls_loss"]._plastic_model_update_weight + == learner_kwargs["cls_plastic_model_update_weight"] + ) + assert ( + model_updater._learner._components["cls_loss"]._stable_model_update_probability + == learner_kwargs["cls_stable_model_update_probability"] + ) + assert ( + model_updater._learner._components["cls_loss"]._plastic_model_update_probability + == learner_kwargs["cls_plastic_model_update_probability"] + ) + assert model_updater._learner._components["pod_loss"].weight == learner_kwargs["pod_alpha"] + assert ( + model_updater._learner._components["pod_loss"]._distillation_type + == 
learner_kwargs["pod_distillation_type"] + ) + assert ( + model_updater._learner._components["pod_loss"]._normalize == learner_kwargs["pod_normalize"] + ) + assert ( + model_updater._learner._components["shrink_perturb"]._shrink_factor + == learner_kwargs["sp_shrink_factor"] + ) + assert model_updater._learner._components["shrink_perturb"]._sigma == learner_kwargs["sp_sigma"] + + @pytest.mark.parametrize( - "cls,kwargs", + "learner_class,validate_fct,learner_kwargs1,learner_kwargs2", [ - [ - ExperienceReplayLearner, - {"alpha": 0.2, "memory_size": 10, "memory_batch_size": 10}, - ], - [ - DarkExperienceReplayLearner, - {"alpha": 0.1, "beta": 0.3, "memory_size": 42}, - ], - [ + ( CLSExperienceReplayLearner, + validate_cls_er, { - "alpha": 0.3, - "beta": 0.1, - "stable_model_update_weight": 0.3, - "plastic_model_update_weight": 0.3, + "memory_size": 30, + "memory_batch_size": 20, + "batch_size": 50, + "seed": 1, + "alpha": 0.123, + "beta": 2, + "stable_model_update_weight": 0.2, + "plastic_model_update_weight": 0.1, "stable_model_update_probability": 0.3, - "plastic_model_update_probability": 0.5, - "memory_size": 42, + "plastic_model_update_probability": 0.4, + }, + { + "memory_size": 30, + "memory_batch_size": 10, + "batch_size": 100, + "seed": 1, + "alpha": 2.3, + "beta": 3, + "stable_model_update_weight": 0.6, + "plastic_model_update_weight": 0.5, + "stable_model_update_probability": 0.7, + "plastic_model_update_probability": 0.8, + }, + ), + ( + DarkExperienceReplayLearner, + validate_dark_er, + { + "memory_size": 30, + "memory_batch_size": 20, + "batch_size": 50, + "seed": 1, + "alpha": 0.123, + "beta": 2, + }, + { + "memory_size": 30, + "memory_batch_size": 10, + "batch_size": 100, + "seed": 1, + "alpha": 2.3, + "beta": 3, }, - ], - [ + ), + ( PooledOutputDistillationExperienceReplayLearner, - {"alpha": 0.3, "distillation_type": "pixel", "normalize": False, "memory_size": 42}, - ], - [ + validate_pod_er, + { + "memory_size": 30, + "memory_batch_size": 20, + "batch_size": 50, + "seed": 1, + "alpha": 0.123, + "distillation_type": "spatial", + "normalize": True, + }, + { + "memory_size": 30, + "memory_batch_size": 10, + "batch_size": 100, + "seed": 1, + "alpha": 0.123, + "distillation_type": "channel", + "normalize": False, + }, + ), + ( SuperExperienceReplayLearner, + validate_super_er, { - "der_alpha": 0.2, - "der_beta": 0.3, - "sp_shrink_factor": 0.1, - "sp_sigma": 0.3, - "cls_alpha": 0.3, - "cls_stable_model_update_weight": 0.4, - "cls_plastic_model_update_weight": 0.4, - "cls_stable_model_update_probability": 0.3, - "cls_plastic_model_update_probability": 0.5, - "pod_alpha": 0.1, - "pod_distillation_type": "pixel", + "memory_size": 30, + "memory_batch_size": 20, + "batch_size": 50, + "seed": 1, + "der_alpha": 0.123, + "der_beta": 2, + "sp_shrink_factor": 0.33, + "sp_sigma": 0.11, + "cls_alpha": 2.3, + "cls_stable_model_update_weight": 0.6, + "cls_plastic_model_update_weight": 0.5, + "cls_stable_model_update_probability": 0.7, + "cls_plastic_model_update_probability": 0.8, + "pod_alpha": 0.13, + "pod_distillation_type": "spatial", + "pod_normalize": True, + }, + { + "memory_size": 30, + "memory_batch_size": 10, + "batch_size": 100, + "seed": 1, + "der_alpha": 2.3, + "der_beta": 3, + "sp_shrink_factor": 0.66, + "sp_sigma": 0.22, + "cls_alpha": 2.3, + "cls_stable_model_update_weight": 0.6, + "cls_plastic_model_update_weight": 0.5, + "cls_stable_model_update_probability": 0.7, + "cls_plastic_model_update_probability": 0.8, + "pod_alpha": 0.123, + "pod_distillation_type": "channel", 
"pod_normalize": False, - "memory_size": 42, }, - ], + ), ], ) -def test_er_components_save_and_load(tmpdir, cls, kwargs): - """This test saves the learner state and reloads it and verifies that the hyperparameters - for the components were correctly set, saved and loaded.""" - model = pytest.helpers.get_renate_module_mlp( - num_inputs=10, num_outputs=10, hidden_size=32, num_hidden_layers=3 +def test_saving_and_loading_of_er_methods( + tmpdir, learner_class, validate_fct, learner_kwargs1, learner_kwargs2 +): + """ER saving and loading test. + + The ER methods partially have custom saving and loading functions (CLS-ER). Furthermore, the + components used to be Modules that effectively did not allow for changing hyperparameter + settings. + """ + model, train_dataset, _ = pytest.helpers.get_renate_module_mlp_and_data( + num_inputs=10, + num_outputs=10, + hidden_size=32, + num_hidden_layers=1, + train_num_samples=10, + test_num_samples=5, ) - learner = cls( - model=model, - loss_fn=pytest.helpers.get_loss_fn(), - optimizer=pytest.helpers.get_partial_optimizer(), - **kwargs, + state_url = defaults.input_state_folder(tmpdir) + + model_updater = pytest.helpers.get_simple_updater( + model, + learner_class=learner_class, + learner_kwargs=learner_kwargs1, + output_state_folder=state_url, + max_epochs=2, ) - torch.save(learner.state_dict(), os.path.join(tmpdir, "learner.pt")) - learner = cls( - model=model, - loss_fn=pytest.helpers.get_loss_fn(), - optimizer=pytest.helpers.get_partial_optimizer(), - **kwargs, - ) - learner.load_state_dict(torch.load(os.path.join(tmpdir, "learner.pt"))) - if isinstance(learner, ExperienceReplayLearner) and not isinstance( - learner, DarkExperienceReplayLearner - ): - assert learner._components["memory_loss"]._weight == kwargs["alpha"] - if isinstance(learner, DarkExperienceReplayLearner): - assert learner._components["mse_loss"]._weight == kwargs["alpha"] - assert learner._components["memory_loss"]._weight == kwargs["beta"] - elif isinstance(learner, PooledOutputDistillationExperienceReplayLearner): - assert learner._components["pod_loss"]._weight == kwargs["alpha"] - assert learner._components["pod_loss"]._distillation_type == kwargs["distillation_type"] - assert learner._components["pod_loss"]._normalize == kwargs["normalize"] - elif isinstance(learner, CLSExperienceReplayLearner): - assert learner._components["memory_loss"]._weight == kwargs["alpha"] - assert learner._components["cls_loss"]._weight == kwargs["beta"] - assert ( - learner._components["cls_loss"]._stable_model_update_weight - == kwargs["stable_model_update_weight"] - ) - assert ( - learner._components["cls_loss"]._plastic_model_update_weight - == kwargs["plastic_model_update_weight"] - ) - assert ( - learner._components["cls_loss"]._stable_model_update_probability - == kwargs["stable_model_update_probability"] - ) - assert ( - learner._components["cls_loss"]._plastic_model_update_probability - == kwargs["plastic_model_update_probability"] - ) - elif isinstance(learner, SuperExperienceReplayLearner): - assert learner._components["mse_loss"]._weight == kwargs["der_alpha"] - assert learner._components["memory_loss"]._weight == kwargs["der_beta"] - assert learner._components["shrink_perturb"]._shrink_factor == kwargs["sp_shrink_factor"] - assert learner._components["shrink_perturb"]._sigma == kwargs["sp_sigma"] - assert learner._components["cls_loss"]._weight == kwargs["cls_alpha"] - assert ( - learner._components["cls_loss"]._stable_model_update_weight - == kwargs["cls_stable_model_update_weight"] - ) - 
assert ( - learner._components["cls_loss"]._plastic_model_update_weight - == kwargs["cls_plastic_model_update_weight"] - ) - assert ( - learner._components["cls_loss"]._stable_model_update_probability - == kwargs["cls_stable_model_update_probability"] - ) - assert ( - learner._components["cls_loss"]._plastic_model_update_probability - == kwargs["cls_plastic_model_update_probability"] - ) - assert learner._components["pod_loss"]._weight == kwargs["pod_alpha"] - assert learner._components["pod_loss"]._distillation_type == kwargs["pod_distillation_type"] - assert learner._components["pod_loss"]._normalize == kwargs["pod_normalize"] + + model = model_updater.update(train_dataset, task_id=defaults.TASK_ID) + validate_fct(model_updater, learner_kwargs1) + model_updater = pytest.helpers.get_simple_updater( + model, + learner_class=learner_class, + learner_kwargs=learner_kwargs2, + input_state_folder=state_url, + max_epochs=2, + ) + validate_fct(model_updater, learner_kwargs2) From 2aea6fbf0a967c26c0a70852faebe343ee578c12 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 27 Jul 2023 15:50:18 +0200 Subject: [PATCH 54/89] Add Micro Average Accuracy (#323) --- src/renate/benchmark/experimentation.py | 18 +++++--- .../evaluation/metrics/classification.py | 44 +++++++++++++++++-- src/renate/utils/module.py | 10 ++--- test/renate/benchmark/test_experimentation.py | 37 +++++++++++++++- .../evaluation/metrics/test_classification.py | 22 ++++++++-- 5 files changed, 109 insertions(+), 22 deletions(-) diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index c2e90fc9..b4a0d72b 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -24,6 +24,7 @@ backward_transfer, forgetting, forward_transfer, + micro_average_accuracy, ) from renate.training import run_training_job from renate.training.training import submit_remote_job @@ -54,6 +55,7 @@ def create_cumulative_metrics() -> List[Tuple[str, Callable]]: """ return [ ("Average Accuracy", average_accuracy), + ("Micro Average Accuracy", micro_average_accuracy), ("Forgetting", forgetting), ("Forward Transfer", forward_transfer), ("Backward Transfer", backward_transfer), @@ -64,6 +66,7 @@ def cumulative_metrics_summary( results: Dict[str, List[List[float]]], cumulative_metrics: List[Tuple[str, Callable]], num_tasks: int, + num_instances: List[int], ) -> pd.DataFrame: """Creates a pandas DataFrame summary with respect to the observed tasks, specified by `num_tasks`. @@ -73,12 +76,13 @@ def cumulative_metrics_summary( metrics. cumulative_metrics: The list of (name, metric) tuples. num_tasks: The total number of tasks. + num_instances: Count of test data points for each task. """ data = [] - for task_id in range(num_tasks + 1): + for task_id in range(num_tasks): row = [task_id + 1] for _, metric in cumulative_metrics: - row.append(metric(results, task_id)) + row.append(metric(results, task_id, num_instances)) data.append(row) column_names = ["Task ID"] + [name for name, _ in cumulative_metrics] @@ -300,6 +304,7 @@ def _execute_experiment_job_locally( assert num_updates == len( data_module.test_data() ), f"The dataset has {len(data_module.test_data())} chunks, expected {num_updates}." 
+ num_instances = [len(data_chunk) for data_chunk in data_module.test_data()] transforms = get_transforms_kwargs(config_module, config_space) metrics_fn_kwargs = get_metrics_fn_kwargs(config_module, config_space) metrics = get_metrics(config_module, **metrics_fn_kwargs) @@ -312,9 +317,8 @@ def _execute_experiment_job_locally( # TODO: evaluate's trainer has to use devices=1: # See https://github.com/Lightning-AI/lightning/issues/2537 # The fix is to launch evaluation in a separate process like training. - results: Dict[str, List[List[float]]] = {} - evaluate_and_record_results( - results, + results = evaluate_and_record_results( + {}, model=model, data_module=data_module, transform=transforms.get("test_transform"), @@ -365,7 +369,7 @@ def _execute_experiment_job_locally( **get_model_fn_kwargs(config_module, config_space), ) - evaluate_and_record_results( + results = evaluate_and_record_results( results, model=model, data_module=data_module, @@ -383,7 +387,7 @@ def _execute_experiment_job_locally( logger.info(df) cumulative_metrics = create_cumulative_metrics() - df = cumulative_metrics_summary(results, cumulative_metrics, num_updates - 1) + df = cumulative_metrics_summary(results, cumulative_metrics, num_updates, num_instances) save_pandas_df_to_csv(df, defaults.metric_summary_file(logs_url)) logger.info("### Cumulative results: ###") logger.info(df) diff --git a/src/renate/evaluation/metrics/classification.py b/src/renate/evaluation/metrics/classification.py index dde5fc5e..1fa2d1f2 100644 --- a/src/renate/evaluation/metrics/classification.py +++ b/src/renate/evaluation/metrics/classification.py @@ -3,7 +3,9 @@ from typing import Dict, List -def average_accuracy(results: Dict[str, List[List[float]]], task_id: int) -> float: +def average_accuracy( + results: Dict[str, List[List[float]]], task_id: int, num_instances: List[int] +) -> float: """Compute the average accuracy of a model. This measure is defined by: @@ -18,11 +20,38 @@ def average_accuracy(results: Dict[str, List[List[float]]], task_id: int) -> flo results: The results dictionary holding all the results with respect to all recorded metrics. task_id: The task index. + num_instances: Count of test data points for each update. """ return sum(results["accuracy"][task_id][: task_id + 1]) / (task_id + 1) -def forgetting(results: Dict[str, List[List[float]]], task_id: int) -> float: +def micro_average_accuracy( + results: Dict[str, List[List[float]]], task_id: int, num_instances: List[int] +) -> float: + """Compute the micro average accuracy of a model. + + This measure is defined by the number of correctly classified data points divided by the + total number of data points. If the number of data points is the same in each update step, + this is the same as ``average_accuracy``. + + Args: + results: The results dictionary holding all the results with respect to all recorded + metrics. + task_id: The task index. + num_instances: Count of test data points for each update. + """ + total_num_instances = sum(num_instances[: task_id + 1]) + return sum( + [ + num_instances[i] / total_num_instances * results["accuracy"][task_id][i] + for i in range(task_id + 1) + ] + ) + + +def forgetting( + results: Dict[str, List[List[float]]], task_id: int, num_instances: List[int] +) -> float: """Compute the forgetting measure of the model. This measure is defined by: @@ -43,6 +72,7 @@ def forgetting(results: Dict[str, List[List[float]]], task_id: int) -> float: results: The results dictionary holding all the results with respect to all recorded metrics. 
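The difference to the plain average only shows up when the test sets of the individual updates differ in size. A small worked example with made-up numbers:

.. code-block:: python

    accuracies = [0.9, 0.6]      # accuracy on the test data of update 1 and update 2
    num_instances = [100, 300]   # number of test data points per update

    average = sum(accuracies) / len(accuracies)
    micro_average = sum(a * n for a, n in zip(accuracies, num_instances)) / sum(num_instances)

    assert average == 0.75
    assert micro_average == 0.675  # (100 * 0.9 + 300 * 0.6) / 400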
task_id: The task index. + num_instances: Count of test data points for each update. """ if task_id == 0: return 0.0 @@ -59,7 +89,9 @@ def f(results: List[List[float]], j: int, i: int) -> float: return sum_f / task_id -def backward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> float: +def backward_transfer( + results: Dict[str, List[List[float]]], task_id: int, num_instances: List[int] +) -> float: """Compute the backward transfer measure of the model. This measure is defined by: @@ -74,6 +106,7 @@ def backward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> fl results: The results dictionary holding all the results with respect to all recorded metrics. task_id: The task index. + num_instances: Count of test data points for each update. """ if task_id == 0: return 0.0 @@ -83,7 +116,9 @@ def backward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> fl ) -def forward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> float: +def forward_transfer( + results: Dict[str, List[List[float]]], task_id: int, num_instances: List[int] +) -> float: """Compute the forward transfer measure of the model. This measure is defined by: @@ -99,6 +134,7 @@ def forward_transfer(results: Dict[str, List[List[float]]], task_id: int) -> flo results: The results dictionary holding all the results with respect to all recorded metrics. task_id: The task index. + num_instances: Count of test data points for each update. """ if task_id == 0: return 0.0 diff --git a/src/renate/utils/module.py b/src/renate/utils/module.py index 6ac35735..9a5fd649 100644 --- a/src/renate/utils/module.py +++ b/src/renate/utils/module.py @@ -56,9 +56,6 @@ def evaluate_and_record_results( precision: Type of bit precision to use. `More details `__ """ - - data_module.setup() - update_results = evaluate( model=model, test_dataset=data_module.test_data(), @@ -74,9 +71,10 @@ def evaluate_and_record_results( precision=precision, ) for key, value in update_results.items(): - if key not in results: - results[key + metric_postfix] = [] - results[key + metric_postfix].append(value) + result_key = f"{key}{metric_postfix}" + if result_key not in results: + results[result_key] = [] + results[result_key].append(value) return results diff --git a/test/renate/benchmark/test_experimentation.py b/test/renate/benchmark/test_experimentation.py index 821b6c46..c6264f97 100644 --- a/test/renate/benchmark/test_experimentation.py +++ b/test/renate/benchmark/test_experimentation.py @@ -5,7 +5,12 @@ import pandas as pd import pytest -from renate.benchmark.experimentation import execute_experiment_job +from renate.benchmark.experimentation import ( + cumulative_metrics_summary, + execute_experiment_job, + individual_metrics_summary, +) +from renate.evaluation.metrics.classification import average_accuracy @pytest.fixture @@ -32,6 +37,7 @@ def test_execute_experiment_job(tmpdir, experiment_job_kwargs, save_state): expected_columns = [ "Task ID", "Average Accuracy", + "Micro Average Accuracy", "Forgetting", "Forward Transfer", "Backward Transfer", @@ -65,3 +71,32 @@ def test_execute_experiment_job_edge_cases(tmpdir, experiment_job_kwargs, update experiment_job_kwargs.update(update_dict) with pytest.raises(AssertionError, match=regex): execute_experiment_job(experiment_outputs_url=tmpdir, **experiment_job_kwargs) + + +def test_cumulative_metrics_summary(): + results = { + "accuracy": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + "accuracy_init": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], 
+ } + metrics = [("Average Accuracy", average_accuracy)] + df = cumulative_metrics_summary( + results=results, + cumulative_metrics=metrics, + num_tasks=3, + num_instances=[10, 20, 30], + ) + assert list(df.columns) == ["Task ID", "Average Accuracy"] + assert pytest.approx(list(df["Average Accuracy"])) == [0.1, 0.45, 0.8] + assert df.shape == (3, 2) + + +def test_individual_metrics_summary(): + results = { + "accuracy": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + "accuracy_init": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]], + } + df = individual_metrics_summary(results=results, current_task=2, num_tasks=3) + assert list(df.columns) == [("Task ID", "")] + [("accuracy", f"Task {i}") for i in range(1, 4)] + assert list(df.iloc[0]) == [1.0, 0.1, 0.2, 0.3] + assert list(df.iloc[1]) == [2.0, 0.4, 0.5, 0.6] + assert df.shape == (2, 4) diff --git a/test/renate/evaluation/metrics/test_classification.py b/test/renate/evaluation/metrics/test_classification.py index c753fa91..9e44b9f9 100644 --- a/test/renate/evaluation/metrics/test_classification.py +++ b/test/renate/evaluation/metrics/test_classification.py @@ -8,6 +8,7 @@ backward_transfer, forgetting, forward_transfer, + micro_average_accuracy, ) @@ -21,7 +22,20 @@ ) def test_average_accuracy(task_id, result): results = SAMPLE_CLASSIFICATION_RESULTS - assert pytest.approx(average_accuracy(results, task_id)) == result + assert pytest.approx(average_accuracy(results, task_id, [])) == result + + +@pytest.mark.parametrize( + "task_id,expected_result,num_instances", + [ + [0, 0.9362000226974487, [10, 5, 10]], + [1, 0.8691666722297668, [10, 5, 10]], + [2, 0.6491599977016449, [10, 5, 10]], + ], +) +def test_micro_average_accuracy(task_id, expected_result, num_instances): + results = SAMPLE_CLASSIFICATION_RESULTS + assert pytest.approx(micro_average_accuracy(results, task_id, num_instances)) == expected_result @pytest.mark.parametrize( @@ -40,7 +54,7 @@ def test_average_accuracy(task_id, result): ) def test_forgetting(task_id, result): results = SAMPLE_CLASSIFICATION_RESULTS - assert pytest.approx(forgetting(results, task_id)) == result + assert pytest.approx(forgetting(results, task_id, [])) == result @pytest.mark.parametrize( @@ -59,7 +73,7 @@ def test_forgetting(task_id, result): ) def test_backward_transfer(task_id, result): results = SAMPLE_CLASSIFICATION_RESULTS - assert pytest.approx(backward_transfer(results, task_id)) == result + assert pytest.approx(backward_transfer(results, task_id, [])) == result @pytest.mark.parametrize( @@ -75,4 +89,4 @@ def test_backward_transfer(task_id, result): ) def test_forward_transfer(task_id, result): results = SAMPLE_CLASSIFICATION_RESULTS - assert pytest.approx(forward_transfer(results, task_id)) == result + assert pytest.approx(forward_transfer(results, task_id, [])) == result From 5ea85b5e2aa1b89335282d53979faa0de8e46441 Mon Sep 17 00:00:00 2001 From: wistuba Date: Tue, 1 Aug 2023 15:41:21 +0200 Subject: [PATCH 55/89] Experimentation Tools (#356) --- benchmarks/README.rst | 35 +++++ .../experiment_configs/datasets/clear10.json | 5 + .../experiment_configs/datasets/clear100.json | 5 + .../datasets/domainnet.json | 4 + .../experiment_configs/datasets/fmow.json | 9 ++ .../experiment_configs/datasets/huffpost.json | 9 ++ .../fine-tuning-clear10.json | 6 + .../fine-tuning-clear100.json | 6 + .../fine-tuning-domainnet.json | 6 + .../experiment_configs/fine-tuning-fmow.json | 6 + .../fine-tuning-huffpost.json | 6 + .../experiment_configs/joint-clear10.json | 6 + 
.../experiment_configs/joint-clear100.json | 6 + .../experiment_configs/joint-domainnet.json | 6 + .../experiment_configs/models/bert.json | 4 + .../models/resnet18-cifar.json | 3 + .../experiment_configs/models/resnet18.json | 3 + .../offline-er-clear10.json | 6 + .../offline-er-clear100.json | 6 + .../offline-er-domainnet.json | 6 + .../scenarios/clear10-10updates.json | 15 +++ .../scenarios/clear100-11updates.json | 15 +++ .../scenarios/domainnet-6updates.json | 15 +++ .../scenarios/wild-time.json | 4 + .../updaters/fine-tuning-clear.json | 4 + .../updaters/fine-tuning-domainnet.json | 4 + .../updaters/fine-tuning-fmow.json | 11 ++ .../updaters/fine-tuning-huffpost.json | 8 ++ .../updaters/joint-clear.json | 4 + .../updaters/joint-domainnet.json | 4 + .../updaters/offline-er-clear10.json | 6 + .../updaters/offline-er-clear100.json | 6 + .../updaters/offline-er-domainnet.json | 6 + benchmarks/requirements.txt | 2 + benchmarks/run_benchmark.py | 126 ++++++++++++++++++ src/renate/defaults.py | 4 +- 36 files changed, 375 insertions(+), 2 deletions(-) create mode 100644 benchmarks/README.rst create mode 100644 benchmarks/experiment_configs/datasets/clear10.json create mode 100644 benchmarks/experiment_configs/datasets/clear100.json create mode 100644 benchmarks/experiment_configs/datasets/domainnet.json create mode 100644 benchmarks/experiment_configs/datasets/fmow.json create mode 100644 benchmarks/experiment_configs/datasets/huffpost.json create mode 100644 benchmarks/experiment_configs/fine-tuning-clear10.json create mode 100644 benchmarks/experiment_configs/fine-tuning-clear100.json create mode 100644 benchmarks/experiment_configs/fine-tuning-domainnet.json create mode 100644 benchmarks/experiment_configs/fine-tuning-fmow.json create mode 100644 benchmarks/experiment_configs/fine-tuning-huffpost.json create mode 100644 benchmarks/experiment_configs/joint-clear10.json create mode 100644 benchmarks/experiment_configs/joint-clear100.json create mode 100644 benchmarks/experiment_configs/joint-domainnet.json create mode 100644 benchmarks/experiment_configs/models/bert.json create mode 100644 benchmarks/experiment_configs/models/resnet18-cifar.json create mode 100644 benchmarks/experiment_configs/models/resnet18.json create mode 100644 benchmarks/experiment_configs/offline-er-clear10.json create mode 100644 benchmarks/experiment_configs/offline-er-clear100.json create mode 100644 benchmarks/experiment_configs/offline-er-domainnet.json create mode 100644 benchmarks/experiment_configs/scenarios/clear10-10updates.json create mode 100644 benchmarks/experiment_configs/scenarios/clear100-11updates.json create mode 100644 benchmarks/experiment_configs/scenarios/domainnet-6updates.json create mode 100644 benchmarks/experiment_configs/scenarios/wild-time.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-clear.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-domainnet.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-fmow.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json create mode 100644 benchmarks/experiment_configs/updaters/joint-clear.json create mode 100644 benchmarks/experiment_configs/updaters/joint-domainnet.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-clear10.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-clear100.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-domainnet.json create mode 100644 
benchmarks/requirements.txt create mode 100644 benchmarks/run_benchmark.py diff --git a/benchmarks/README.rst b/benchmarks/README.rst new file mode 100644 index 00000000..ecbc1afe --- /dev/null +++ b/benchmarks/README.rst @@ -0,0 +1,35 @@ +Running Standard Benchmarks +****************************************************************************** + +This folder contains a couple of files to run experiments with Renate. +Add new experimentation configuration files to add new experiments. +Edit the ``requirements.txt`` to add additional requirements for an experiment (SageMaker only). + +Instructions +============ +1. Clone the repository +2. ``cd Renate`` and install Renate +3. Run a benchmark via + + .. code-block:: bash + + python benchmarks/run_benchmark.py --benchmark-file fine-tuning-clear10.json \ + --backend sagemaker --budget-factor 1 --job-name clear10-finetuning-1 --num-repetitions 1 + + This is an example command to run an experiment with ClEAR10 on SageMaker using a Fine-Tuning + updater. + + Quick explanation of the arguments: + + - ``benchmark-file``: Any filename of files in ``experiment_configs``. This file specifies all + properties of the experiment, i.e., dataset, scenario, updater, and hyperparameters settings. + Modify or add more to run own experiments. + - ``backend``: Run the experiment on SageMaker (``sagemaker``) or locally (``local``). + - ``budget-factor``: Each update run will make ``budget_factor * max_epochs`` many passes over + the new data during training. ``max_epochs`` is typically defines as part of the scenario + .json file. Default: ``1``. + - ``job-name``: Defines the name of the output folder and the name of the SageMaker training + job. + - ``num-repetitions``: The number of times the experiment will be repeated. Only the seed + differs between repetitions. + - ``max-time``: The wall clock time spent per update. Default: 5 days. 
\ No newline at end of file diff --git a/benchmarks/experiment_configs/datasets/clear10.json b/benchmarks/experiment_configs/datasets/clear10.json new file mode 100644 index 00000000..6638e544 --- /dev/null +++ b/benchmarks/experiment_configs/datasets/clear10.json @@ -0,0 +1,5 @@ +{ + "dataset_name": "CLEAR10", + "num_inputs": 50176, + "num_outputs": 11 +} diff --git a/benchmarks/experiment_configs/datasets/clear100.json b/benchmarks/experiment_configs/datasets/clear100.json new file mode 100644 index 00000000..a07bd5cc --- /dev/null +++ b/benchmarks/experiment_configs/datasets/clear100.json @@ -0,0 +1,5 @@ +{ + "dataset_name": "CLEAR100", + "num_inputs": 50176, + "num_outputs": 100 +} diff --git a/benchmarks/experiment_configs/datasets/domainnet.json b/benchmarks/experiment_configs/datasets/domainnet.json new file mode 100644 index 00000000..f97a6b0b --- /dev/null +++ b/benchmarks/experiment_configs/datasets/domainnet.json @@ -0,0 +1,4 @@ +{ + "dataset_name": "DomainNet", + "num_outputs": 345 +} diff --git a/benchmarks/experiment_configs/datasets/fmow.json b/benchmarks/experiment_configs/datasets/fmow.json new file mode 100644 index 00000000..eb36763b --- /dev/null +++ b/benchmarks/experiment_configs/datasets/fmow.json @@ -0,0 +1,9 @@ +{ + "dataset_name": "fmow", + "src_bucket": "mnemosyne-team-bucket", + "src_object_name": "dataset/wildtime/fmow.hdf5", + "num_inputs": 150528, + "num_outputs": 62, + "num_tasks": 16, + "max_epochs": 50 +} diff --git a/benchmarks/experiment_configs/datasets/huffpost.json b/benchmarks/experiment_configs/datasets/huffpost.json new file mode 100644 index 00000000..af65f008 --- /dev/null +++ b/benchmarks/experiment_configs/datasets/huffpost.json @@ -0,0 +1,9 @@ +{ + "dataset_name": "huffpost", + "src_bucket": "mnemosyne-team-bucket", + "src_object_name": "dataset/wildtime/huffpost.hdf5", + "num_inputs": 0, + "num_outputs": 11, + "num_tasks": 7, + "max_epochs": 5 +} diff --git a/benchmarks/experiment_configs/fine-tuning-clear10.json b/benchmarks/experiment_configs/fine-tuning-clear10.json new file mode 100644 index 00000000..b730510b --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-clear10.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-10updates.json", + "model": "resnet18.json", + "updater": "fine-tuning-clear.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-clear100.json b/benchmarks/experiment_configs/fine-tuning-clear100.json new file mode 100644 index 00000000..7fa409a3 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-clear100.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-11updates.json", + "model": "resnet18.json", + "updater": "fine-tuning-clear.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-domainnet.json b/benchmarks/experiment_configs/fine-tuning-domainnet.json new file mode 100644 index 00000000..04ee59e2 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-domainnet.json @@ -0,0 +1,6 @@ +{ + "scenario": "domainnet-6updates.json", + "model": "resnet18.json", + "updater": "fine-tuning-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-fmow.json b/benchmarks/experiment_configs/fine-tuning-fmow.json new file mode 100644 index 00000000..060b25d4 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-fmow.json @@ -0,0 +1,6 @@ +{ + "scenario": "wild-time.json", + "model": "resnet18.json", + "updater": "fine-tuning-fmow.json", + "dataset": "fmow.json" +} diff --git 
a/benchmarks/experiment_configs/fine-tuning-huffpost.json b/benchmarks/experiment_configs/fine-tuning-huffpost.json new file mode 100644 index 00000000..b3478c77 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-huffpost.json @@ -0,0 +1,6 @@ +{ + "scenario": "wild-time.json", + "model": "bert.json", + "updater": "fine-tuning-huffpost.json", + "dataset": "huffpost.json" +} diff --git a/benchmarks/experiment_configs/joint-clear10.json b/benchmarks/experiment_configs/joint-clear10.json new file mode 100644 index 00000000..5f160ffc --- /dev/null +++ b/benchmarks/experiment_configs/joint-clear10.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-10updates.json", + "model": "resnet18.json", + "updater": "joint-clear.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/joint-clear100.json b/benchmarks/experiment_configs/joint-clear100.json new file mode 100644 index 00000000..a8a92f45 --- /dev/null +++ b/benchmarks/experiment_configs/joint-clear100.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-11updates.json", + "model": "resnet18.json", + "updater": "joint-clear.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/joint-domainnet.json b/benchmarks/experiment_configs/joint-domainnet.json new file mode 100644 index 00000000..973b9083 --- /dev/null +++ b/benchmarks/experiment_configs/joint-domainnet.json @@ -0,0 +1,6 @@ +{ + "scenario": "domainnet-6updates.json", + "model": "resnet18.json", + "updater": "joint-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/models/bert.json b/benchmarks/experiment_configs/models/bert.json new file mode 100644 index 00000000..55f3ba6e --- /dev/null +++ b/benchmarks/experiment_configs/models/bert.json @@ -0,0 +1,4 @@ +{ + "model_name": "HuggingFaceTransformer", + "pretrained_model_name": "bert-base-uncased" +} diff --git a/benchmarks/experiment_configs/models/resnet18-cifar.json b/benchmarks/experiment_configs/models/resnet18-cifar.json new file mode 100644 index 00000000..9d3c0e9d --- /dev/null +++ b/benchmarks/experiment_configs/models/resnet18-cifar.json @@ -0,0 +1,3 @@ +{ + "model_name": "ResNet18CIFAR" +} diff --git a/benchmarks/experiment_configs/models/resnet18.json b/benchmarks/experiment_configs/models/resnet18.json new file mode 100644 index 00000000..b06643ea --- /dev/null +++ b/benchmarks/experiment_configs/models/resnet18.json @@ -0,0 +1,3 @@ +{ + "model_name": "ResNet18" +} diff --git a/benchmarks/experiment_configs/offline-er-clear10.json b/benchmarks/experiment_configs/offline-er-clear10.json new file mode 100644 index 00000000..a4b3f11a --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-clear10.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-10updates.json", + "model": "resnet18.json", + "updater": "offline-er-clear10.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/offline-er-clear100.json b/benchmarks/experiment_configs/offline-er-clear100.json new file mode 100644 index 00000000..6642915c --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-clear100.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-11updates.json", + "model": "resnet18.json", + "updater": "offline-er-clear100.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/offline-er-domainnet.json b/benchmarks/experiment_configs/offline-er-domainnet.json new file mode 100644 index 00000000..fc9aee14 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-domainnet.json @@ -0,0 +1,6 @@ +{ + "scenario": 
"domainnet-6updates.json", + "model": "resnet18.json", + "updater": "offline-er-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/scenarios/clear10-10updates.json b/benchmarks/experiment_configs/scenarios/clear10-10updates.json new file mode 100644 index 00000000..04ff3289 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/clear10-10updates.json @@ -0,0 +1,15 @@ +{ + "val_size": 0.1, + "scenario_name": "TimeIncrementalScenario", + "num_tasks": 10, + "max_epochs": 100, + "optimizer": "SGD", + "learning_rate": 0.01, + "momentum": 0.9, + "weight_decay": 1e-5, + "batch_size": 256, + "learning_rate_scheduler": "StepLR", + "learning_rate_scheduler_step_size": 30, + "learning_rate_scheduler_gamma": 0.1, + "learning_rate_scheduler_interval": "epoch" +} diff --git a/benchmarks/experiment_configs/scenarios/clear100-11updates.json b/benchmarks/experiment_configs/scenarios/clear100-11updates.json new file mode 100644 index 00000000..f020fc76 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/clear100-11updates.json @@ -0,0 +1,15 @@ +{ + "val_size": 0.1, + "scenario_name": "TimeIncrementalScenario", + "num_tasks": 11, + "max_epochs": 100, + "optimizer": "SGD", + "learning_rate": 0.01, + "momentum": 0.9, + "weight_decay": 1e-5, + "batch_size": 256, + "learning_rate_scheduler": "StepLR", + "learning_rate_scheduler_step_size": 30, + "learning_rate_scheduler_gamma": 0.1, + "learning_rate_scheduler_interval": "epoch" +} diff --git a/benchmarks/experiment_configs/scenarios/domainnet-6updates.json b/benchmarks/experiment_configs/scenarios/domainnet-6updates.json new file mode 100644 index 00000000..fcc0dc34 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/domainnet-6updates.json @@ -0,0 +1,15 @@ +{ + "val_size": 0.3, + "scenario_name": "DomainNetScenario", + "num_tasks": 6, + "max_epochs": 50, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 50, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0, + "domains": ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] +} diff --git a/benchmarks/experiment_configs/scenarios/wild-time.json b/benchmarks/experiment_configs/scenarios/wild-time.json new file mode 100644 index 00000000..cd9cff4d --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/wild-time.json @@ -0,0 +1,4 @@ +{ + "val_size": 0.1, + "scenario_name": "TimeIncrementalScenario" +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-clear.json b/benchmarks/experiment_configs/updaters/fine-tuning-clear.json new file mode 100644 index 00000000..44a7baa4 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-clear.json @@ -0,0 +1,4 @@ +{ + "updater": "FineTuning", + "batch_size": 256 +} \ No newline at end of file diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-domainnet.json b/benchmarks/experiment_configs/updaters/fine-tuning-domainnet.json new file mode 100644 index 00000000..b650accf --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-domainnet.json @@ -0,0 +1,4 @@ +{ + "updater": "FineTuning", + "batch_size": 64 +} \ No newline at end of file diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json b/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json new file mode 100644 index 00000000..77040f12 --- /dev/null +++ 
b/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json @@ -0,0 +1,11 @@ +{ + "updater": "FineTuning", + "optimizer": "SGD", + "learning_rate": 0.03, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 50, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json b/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json new file mode 100644 index 00000000..bdb46bc2 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json @@ -0,0 +1,8 @@ +{ + "updater": "FineTuning", + "optimizer": "Adam", + "learning_rate": 0.0001, + "momentum": 0.9, + "weight_decay": 0.0, + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/joint-clear.json b/benchmarks/experiment_configs/updaters/joint-clear.json new file mode 100644 index 00000000..fe830532 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/joint-clear.json @@ -0,0 +1,4 @@ +{ + "updater": "Joint", + "batch_size": 256 +} diff --git a/benchmarks/experiment_configs/updaters/joint-domainnet.json b/benchmarks/experiment_configs/updaters/joint-domainnet.json new file mode 100644 index 00000000..dbd3d97c --- /dev/null +++ b/benchmarks/experiment_configs/updaters/joint-domainnet.json @@ -0,0 +1,4 @@ +{ + "updater": "Joint", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear10.json b/benchmarks/experiment_configs/updaters/offline-er-clear10.json new file mode 100644 index 00000000..6fe89caf --- /dev/null +++ b/benchmarks/experiment_configs/updaters/offline-er-clear10.json @@ -0,0 +1,6 @@ +{ + "updater": "Offline-ER", + "batch_size": 128, + "memory_batch_size": 128, + "memory_size": 3300 +} diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear100.json b/benchmarks/experiment_configs/updaters/offline-er-clear100.json new file mode 100644 index 00000000..a6d09005 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/offline-er-clear100.json @@ -0,0 +1,6 @@ +{ + "updater": "Offline-ER", + "batch_size": 128, + "memory_batch_size": 128, + "memory_size": 10000 +} diff --git a/benchmarks/experiment_configs/updaters/offline-er-domainnet.json b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json new file mode 100644 index 00000000..b336094f --- /dev/null +++ b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json @@ -0,0 +1,6 @@ +{ + "updater": "Offline-ER", + "batch_size": 32, + "memory_batch_size": 32, + "memory_size": 3450 +} diff --git a/benchmarks/requirements.txt b/benchmarks/requirements.txt new file mode 100644 index 00000000..437a2942 --- /dev/null +++ b/benchmarks/requirements.txt @@ -0,0 +1,2 @@ +Renate +wild-time-data diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py new file mode 100644 index 00000000..6f56e0b1 --- /dev/null +++ b/benchmarks/run_benchmark.py @@ -0,0 +1,126 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import argparse +import json +import os +from pathlib import Path + +import boto3 +from syne_tune.backend.sagemaker_backend.sagemaker_utils import get_execution_role + +from renate.benchmark.experimentation import execute_experiment_job, experiment_config_file + + +def load_config(scenario_file, model_file, updater_file, dataset_file): + cs = {} + for file in [scenario_file, model_file, updater_file, dataset_file]: + with open(file) as f: + cs.update(json.load(f)) + return cs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark-file", + type=str, + required=True, + help="Test suite to run.", + ) + parser.add_argument( + f"--backend", + type=str, + required=True, + choices=["local", "sagemaker"], + help="Whether to run locally or on SageMaker.", + ) + parser.add_argument( + f"--budget-factor", + type=float, + required=True, + help="Use budget_factor * N many fine-tuning epochs, " + "where N is the default number of epochs.", + ) + parser.add_argument( + f"--job-name", + type=str, + required=True, + help="Name of the job.", + ) + parser.add_argument( + f"--num-repetitions", + type=int, + default=1, + help="Number of runs with different seeds.", + ) + parser.add_argument( + f"--max-time", + type=int, + default=5 * 24 * 3600, + help="Maximum execution time.", + ) + args = parser.parse_args() + current_folder = Path(os.path.dirname(__file__)) + configs_folder = current_folder / "experiment_configs" + benchmark_file = configs_folder / args.benchmark_file + if not benchmark_file.is_file(): + raise FileNotFoundError(f"Unknown benchmark file '{benchmark_file}'.") + with open(benchmark_file) as f: + benchmark_config = json.load(f) + config_space = load_config( + os.path.join(configs_folder, "scenarios", benchmark_config["scenario"]), + os.path.join(configs_folder, "models", benchmark_config["model"]), + os.path.join(configs_folder, "updaters", benchmark_config["updater"]), + os.path.join(configs_folder, "datasets", benchmark_config["dataset"]), + ) + config_space["max_epochs"] = int(args.budget_factor * config_space["max_epochs"]) + if "learning_rate_scheduler_step_size" in config_space: + config_space["learning_rate_scheduler_step_size"] = int( + args.budget_factor * config_space["learning_rate_scheduler_step_size"] + ) + if "learning_rate_scheduler_t_max" in config_space: + config_space["learning_rate_scheduler_t_max"] = int( + args.budget_factor * config_space["learning_rate_scheduler_t_max"] + ) + current_folder = Path(os.path.dirname(__file__)) + + role = None if args.backend == "local" else get_execution_role() + + for seed in range(args.num_repetitions): + if args.backend == "local": + experiment_outputs_url = ( + Path("tmp") + / "renate-integration-tests" + / args.test_suite + / args.job_name + / str(seed) + ) + role = None + working_directory = str(Path("tmp") / "renate_working_dir") + else: + AWS_ACCOUNT_ID = boto3.client("sts").get_caller_identity().get("Account") + experiment_outputs_url = ( + f"s3://sagemaker-us-west-2-{AWS_ACCOUNT_ID}/renate-domain-incremental/" + f"{args.job_name}/{seed}" + ) + working_directory = "/tmp/renate_working_dir" + execute_experiment_job( + backend=args.backend, + config_file=experiment_config_file(), + config_space=config_space, + experiment_outputs_url=experiment_outputs_url, + working_directory=working_directory, + mode="max", + metric="val_accuracy", + num_updates=config_space["num_tasks"], + role=role, + instance_type="ml.g4dn.xlarge", + n_workers=1, + 
max_time=args.max_time, + seed=seed, + job_name=args.job_name[:36], + devices=1, + strategy="ddp", + save_state=False, + requirements_file=str(current_folder / "requirements.txt"), + ) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 8e193270..27eccd04 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -39,8 +39,8 @@ INSTANCE_MAX_TIME = 3 * 24 * 3600 N_WORKERS = 1 INSTANCE_TYPE = "ml.c5.xlarge" -PYTHON_VERSION = "py38" -FRAMEWORK_VERSION = "1.12.0" +PYTHON_VERSION = "py39" +FRAMEWORK_VERSION = "1.13.1" TASK_ID = "default_task" WORKING_DIRECTORY = "renate_working_dir" From d8eefcddc216c745058bf327a9948e3b88ece151 Mon Sep 17 00:00:00 2001 From: wistuba Date: Wed, 2 Aug 2023 10:56:02 +0200 Subject: [PATCH 56/89] Cleanup iCarl arguments and implementation (#358) --- src/renate/cli/parsing_functions.py | 15 ++++++++++++++- src/renate/updaters/avalanche/learner.py | 8 ++++++-- src/renate/updaters/avalanche/model_updater.py | 2 -- test/conftest.py | 1 - 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 427c0bc5..82d20612 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -693,6 +693,19 @@ def _add_avalanche_lwf_learner_arguments(arguments: Dict[str, Dict[str, Any]]) - ) +def _add_avalanche_icarl_learner_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: + """A helper function that adds iCarl arguments.""" + arguments.update( + { + "memory_size": { + "type": int, + "default": defaults.MEMORY_SIZE, + "help": f"Number of exemplars being stored. Default: {defaults.MEMORY_SIZE}.", + } + }, + ) + + def get_function_kwargs(args: argparse.Namespace, function_args: Dict[str, Any]) -> Dict[str, Any]: """Returns the kwargs for a function with defined arguments based on provided values. 
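The updater-specific helpers in ``parsing_functions.py`` all follow the same convention: each CLI argument is described by a nested dictionary with ``type``, ``default``, and ``help`` entries. As a rough, hypothetical illustration (not the actual Renate parsing code), such a definition could be turned into ``argparse`` arguments as follows:

.. code-block:: python

    import argparse

    # Hypothetical stand-in for the dictionary filled by
    # _add_avalanche_icarl_learner_arguments above.
    arguments = {
        "memory_size": {"type": int, "default": 32, "help": "Number of exemplars being stored."},
    }

    parser = argparse.ArgumentParser()
    for name, spec in arguments.items():
        parser.add_argument(f"--{name}", **spec)

    args = parser.parse_args(["--memory_size", "500"])
    print(args.memory_size)  # 500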
@@ -909,5 +922,5 @@ def get_scheduler_kwargs( "Avalanche-ER": _add_experience_replay_arguments, "Avalanche-EWC": _add_avalanche_ewc_learner_arguments, "Avalanche-LwF": _add_avalanche_lwf_learner_arguments, - "Avalanche-iCaRL": _add_experience_replay_arguments, + "Avalanche-iCaRL": _add_avalanche_icarl_learner_arguments, } diff --git a/src/renate/updaters/avalanche/learner.py b/src/renate/updaters/avalanche/learner.py index 61da8fec..ff258cae 100644 --- a/src/renate/updaters/avalanche/learner.py +++ b/src/renate/updaters/avalanche/learner.py @@ -123,9 +123,13 @@ def create_avalanche_learner( return self._create_avalanche_learner(plugins=plugins, **kwargs) -class AvalancheICaRLLearner(ReplayLearner, AvalancheLoaderMixin): +class AvalancheICaRLLearner(Learner, AvalancheLoaderMixin): """Renate wrapper around Avalanche ICaRL.""" + def __init__(self, memory_size: int, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._memory_size = memory_size + def create_avalanche_learner( self, optimizer: Optimizer, @@ -154,7 +158,7 @@ def create_avalanche_learner( feature_extractor=self._model.get_backbone(), classifier=self._model.get_predictor(), optimizer=optimizer, - memory_size=self._memory_buffer._max_size, + memory_size=self._memory_size, buffer_transform=None, # TODO fixed_memory=True, train_mb_size=self._batch_size, diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index 92eda453..6f2839a0 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -444,7 +444,6 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, @@ -470,7 +469,6 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, "batch_size": batch_size, "seed": seed, } diff --git a/test/conftest.py b/test/conftest.py index 81097979..921f6540 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -107,7 +107,6 @@ def pytest_collection_modifyitems(config, items): }, AvalancheICaRLLearner: { "memory_size": 30, - "memory_batch_size": 20, "batch_size": 50, "seed": 1, }, From 44bffb4e71c1c616213c11f8c75c5da0ad0d5b2a Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 3 Aug 2023 10:59:06 +0200 Subject: [PATCH 57/89] Add DataIncrementalDataModule, DataIncrementalScenario and DomainNet benchmark (#357) --- .../scenarios/clear10-10updates.json | 2 +- .../scenarios/clear100-11updates.json | 2 +- .../scenarios/domainnet-6updates.json | 4 +- .../scenarios/wild-time.json | 2 +- doc/benchmarking/renate_benchmarks.rst | 26 +++- src/renate/benchmark/datasets/base.py | 45 +++++++ .../benchmark/datasets/vision_datasets.py | 125 +++++++++++++++++- .../benchmark/datasets/wild_time_data.py | 8 +- src/renate/benchmark/experiment_config.py | 51 +++++-- src/renate/benchmark/scenarios.py | 43 +++--- .../benchmark/test_experimentation_config.py | 38 ++++-- test/renate/benchmark/test_scenarios.py | 10 +- 12 files changed, 293 insertions(+), 63 deletions(-) create mode 100644 src/renate/benchmark/datasets/base.py diff --git a/benchmarks/experiment_configs/scenarios/clear10-10updates.json b/benchmarks/experiment_configs/scenarios/clear10-10updates.json 
index 04ff3289..46c52936 100644 --- a/benchmarks/experiment_configs/scenarios/clear10-10updates.json +++ b/benchmarks/experiment_configs/scenarios/clear10-10updates.json @@ -1,6 +1,6 @@ { "val_size": 0.1, - "scenario_name": "TimeIncrementalScenario", + "scenario_name": "DataIncrementalScenario", "num_tasks": 10, "max_epochs": 100, "optimizer": "SGD", diff --git a/benchmarks/experiment_configs/scenarios/clear100-11updates.json b/benchmarks/experiment_configs/scenarios/clear100-11updates.json index f020fc76..8fafd4c4 100644 --- a/benchmarks/experiment_configs/scenarios/clear100-11updates.json +++ b/benchmarks/experiment_configs/scenarios/clear100-11updates.json @@ -1,6 +1,6 @@ { "val_size": 0.1, - "scenario_name": "TimeIncrementalScenario", + "scenario_name": "DataIncrementalScenario", "num_tasks": 11, "max_epochs": 100, "optimizer": "SGD", diff --git a/benchmarks/experiment_configs/scenarios/domainnet-6updates.json b/benchmarks/experiment_configs/scenarios/domainnet-6updates.json index fcc0dc34..411e7224 100644 --- a/benchmarks/experiment_configs/scenarios/domainnet-6updates.json +++ b/benchmarks/experiment_configs/scenarios/domainnet-6updates.json @@ -1,6 +1,6 @@ { "val_size": 0.3, - "scenario_name": "DomainNetScenario", + "scenario_name": "DataIncrementalScenario", "num_tasks": 6, "max_epochs": 50, "optimizer": "SGD", @@ -11,5 +11,5 @@ "learning_rate_scheduler_interval": "step", "momentum": 0.0, "weight_decay": 0.0, - "domains": ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] + "data_ids": ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] } diff --git a/benchmarks/experiment_configs/scenarios/wild-time.json b/benchmarks/experiment_configs/scenarios/wild-time.json index cd9cff4d..ec2b116b 100644 --- a/benchmarks/experiment_configs/scenarios/wild-time.json +++ b/benchmarks/experiment_configs/scenarios/wild-time.json @@ -1,4 +1,4 @@ { "val_size": 0.1, - "scenario_name": "TimeIncrementalScenario" + "scenario_name": "DataIncrementalScenario" } diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index fcb8c82a..b97501a7 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -133,6 +133,18 @@ The following table contains the list of supported datasets. - Image Classification - 50k train, 10k test, 100 classes, image shape 32x32x3 - Alex Krizhevsky: Learning Multiple Layers of Features from Tiny Images. 2009. + * - CLEAR10 + - Image Classification + - 10 different datasets, one for each year. Each with 3,300 train, 550 test, 11 classes + - Zhiqiu Lin et al.: The CLEAR Benchmark: Continual LEArning on Real-World Imagery. NeurIPS Datasets and Benchmarks 2021. + * - CLEAR100 + - Image Classification + - 11 different datasets, one for each year. Each with roughly 10k train, 5k test, 100 classes + - Zhiqiu Lin et al.: The CLEAR Benchmark: Continual LEArning on Real-World Imagery. NeurIPS Datasets and Benchmarks 2021. + * - DomainNet + - Image Classification + - 6 datasets from different domains. 345 classes, number of train and test image varies + - Xingchao Peng et al.: Moment Matching for Multi-Source Domain Adaptation. ICCV 2019. 
* - FashionMNIST - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 @@ -184,13 +196,13 @@ The first part contains all instances with classes 1 and 2, the second with clas * - Scenario Name - Description - Settings - * - :py:class:`~renate.benchmark.scenarios.TimeIncrementalScenario` - - Used in combination only with Wild-Time datasets or CLEAR. - Data is presented time step by time step and the model is evaluated on test data up to the - current time step. - This means that for the Wild-Time datasets, is a different scenario than in the original - Wild-Time data paper. - - * :code:`num_tasks`: Number of data partitions. + * - :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario` + - Used in combination only with :py:class:`~renate.benchmark.datasets.base.DataIncrementalDataModule`, + e.g., Wild-Time datasets, CLEAR, or DomainNet. + Data is presented data by data, where the data could represent a domain or a time slice. + - * :code:`num_tasks`: You can provide this argument if the different datasets are identified by + ids 0 to `num_tasks`. This is the case for time-incremental datasets such as CLEAR or Wild-Time. + * :code:`data_ids`: List of data identifiers. Used for DomainNet to select order or subset of domains. * - :py:class:`~renate.benchmark.scenarios.ClassIncrementalScenario` - Creates data partitions by splitting the data according to class labels. - * :code:`class_groupings`: Tuple of tuples containing the class labels, e.g., ``((1, ), (2, 3, 4))``. diff --git a/src/renate/benchmark/datasets/base.py b/src/renate/benchmark/datasets/base.py new file mode 100644 index 00000000..83b3cc39 --- /dev/null +++ b/src/renate/benchmark/datasets/base.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from abc import ABC +from pathlib import Path +from typing import Optional, Union + +from renate import defaults +from renate.data.data_module import RenateDataModule + + +class DataIncrementalDataModule(RenateDataModule, ABC): + """Base class for all :py:class:`~renate.data.data_module.RenateDataModule` compatible with + :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario`. + + Defines the API required by the :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario`. + All classes extending this class must load the datasets corresponding to the value in + ``data_id`` whenever ``setup()`` is called. + + Args: + data_path: the path to the folder containing the dataset files. + data_id: Time slice to be loaded. + src_bucket: the name of the s3 bucket. If not provided, downloads the data from original + source. + src_object_name: the folder path in the s3 bucket. + val_size: Fraction of the training data to be used for validation. + seed: Seed used to fix random number generation. 
+ """ + + def __init__( + self, + data_path: Union[Path, str], + data_id: Union[int, str], + src_bucket: Optional[str] = None, + src_object_name: Optional[str] = None, + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__( + data_path=data_path, + src_bucket=src_bucket, + src_object_name=src_object_name, + val_size=val_size, + seed=seed, + ) + self.data_id = data_id diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 215e978d..11d5c90b 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -5,14 +5,16 @@ from pathlib import Path from typing import List, Optional, Tuple, Union +import pandas as pd import torch import torchvision from torchvision import transforms from renate import defaults +from renate.benchmark.datasets.base import DataIncrementalDataModule from renate.data import ImageDataset from renate.data.data_module import RenateDataModule -from renate.utils.file import download_and_unzip_file, download_folder_from_s3 +from renate.utils.file import download_and_unzip_file, download_file, download_folder_from_s3 class TinyImageNetDataModule(RenateDataModule): @@ -187,7 +189,7 @@ def to_long(x): return torch.tensor(x, dtype=torch.long) -class CLEARDataModule(RenateDataModule): +class CLEARDataModule(DataIncrementalDataModule): """Datamodule that process CLEAR datasets: CLEAR10 and CLEAR100. Source: https://clear-benchmark.github.io/. @@ -227,6 +229,7 @@ def __init__( ): super(CLEARDataModule, self).__init__( data_path=data_path, + data_id=time_step, src_bucket=src_bucket, src_object_name=src_object_name, val_size=val_size, @@ -234,8 +237,7 @@ def __init__( ) self._dataset_name = dataset_name.lower() assert self._dataset_name in ["clear10", "clear100"] - assert 0 <= time_step <= (9 if self._dataset_name == "clear10" else 10) - self.time_step = time_step + assert 0 <= self.data_id <= (9 if self._dataset_name == "clear10" else 10) def prepare_data(self) -> None: """Download CLEAR dataset with given dataset_name (clear10/clear100).""" @@ -255,7 +257,7 @@ def prepare_data(self) -> None: def setup(self) -> None: """Set up train, test and val datasets.""" - time_step = self.time_step + 1 if self._dataset_name == "clear10" else self.time_step + time_step = self.data_id + 1 if self._dataset_name == "clear10" else self.data_id X, y = self._get_filepaths_and_labels(train=True, time_step=time_step) train_data = ImageDataset(X, y, transform=transforms.ToTensor()) self._train_data, self._val_data = self._split_train_val_data(train_data) @@ -286,3 +288,116 @@ def _get_filepaths_and_labels(self, train: bool, time_step: int) -> Tuple[List[s labels.append(label) return image_paths, labels + + +class DomainNetDataModule(DataIncrementalDataModule): + """Datamodule that provides access to DomainNet. + + Args: + data_path: the path to the folder containing the dataset files. + src_bucket: the name of the s3 bucket. If not provided, downloads the data from original + source. + src_object_name: the folder path in the s3 bucket. + domain: DomainNet domain name, options are clipart, infograph, painting, quickdraw, real, + and sketch. + val_size: Fraction of the training data to be used for validation. + seed: Seed used to fix random number generation. 
+ """ + + md5s = { + "clipart.zip": "cd0d8f2d77a4e181449b78ed62bccf1e", + "clipart_train.txt": "b4349693a7f9c05c53955725c47ed6cb", + "clipart_test.txt": "f5ddbcfd657a3acf9d0f7da10db22565", + "infograph.zip": "720380b86f9e6ab4805bb38b6bd135f8", + "infograph_train.txt": "379b50054f4ac2018dca4f89421b92d9", + "infograph_test.txt": "779626b50869edffe8ea6941c3755c71", + "painting.zip": "1ae32cdb4f98fe7ab5eb0a351768abfd", + "painting_train.txt": "b732ced3939ac8efdd8c0a889dca56cc", + "painting_test.txt": "c1a828fdfe216fb109f1c0083a252c6f", + "quickdraw.zip": "bdc1b6f09f277da1a263389efe0c7a66", + "quickdraw_train.txt": "b4349693a7f9c05c53955725c47ed6cb", + "quickdraw_test.txt": "f5ddbcfd657a3acf9d0f7da10db22565", + "real.zip": "dcc47055e8935767784b7162e7c7cca6", + "real_train.txt": "8ebf02c2075fadd564705f0dc7cd6291", + "real_test.txt": "6098816791c3ebed543c71ffa11b9054", + "sketch.zip": "658d8009644040ff7ce30bb2e820850f", + "sketch_train.txt": "1233bd18aa9a8a200bf4cecf1c34ef3e", + "sketch_test.txt": "d8a222e4672cfd585298aa14d02ea441", + } + + domains = ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] + dataset_stats = { + "clipart": {"mean": [0.7395, 0.7195, 0.6865], "std": [0.3621, 0.3640, 0.3873]}, + "infograph": {"mean": [0.6882, 0.6962, 0.6644], "std": [0.3328, 0.3095, 0.3277]}, + "painting": {"mean": [0.5737, 0.5456, 0.5067], "std": [0.3079, 0.3003, 0.3161]}, + "quickdraw": {"mean": [0.9525, 0.9525, 0.9525], "std": [0.2127, 0.2127, 0.2127]}, + "real": {"mean": [0.6066, 0.5897, 0.5564], "std": [0.3335, 0.3270, 0.3485]}, + "sketch": {"mean": [0.8325, 0.8269, 0.8180], "std": [0.2723, 0.2747, 0.2801]}, + "all": {"mean": [0.7491, 0.7391, 0.7179], "std": [0.3318, 0.3314, 0.3512]}, + } + + def __init__( + self, + data_path: Union[Path, str], + src_bucket: Optional[str] = None, + src_object_name: Optional[str] = None, + domain: str = "clipart", + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__( + data_path=data_path, + data_id=domain.lower(), + src_bucket=src_bucket, + src_object_name=src_object_name, + val_size=val_size, + seed=seed, + ) + assert self.data_id in self.domains, f"Unknown domain {self.data_id}." 
+ + def prepare_data(self) -> None: + """Download DomainNet dataset for given domain.""" + file_name = f"{self.data_id}.zip" + url = "http://csr.bu.edu/ftp/visda/2019/multi-source/" + if self.data_id in ["clipart", "painting"]: + url = os.path.join(url, "groundtruth") + if not self._verify_file(file_name): + download_and_unzip_file( + self.data_id, + self._data_path, + self._src_bucket, + self._src_object_name, + url, + file_name, + ) + for file_name in [f"{self.data_id}_train.txt", f"{self.data_id}_test.txt"]: + if not self._verify_file(file_name): + download_file( + self.data_id, + self._data_path, + self._src_bucket, + self._src_object_name, + "http://csr.bu.edu/ftp/visda/2019/multi-source/domainnet/txt/", + file_name, + ) + + def setup(self) -> None: + """Set up train, test and val datasets.""" + X, y = self._get_filepaths_and_labels("train") + train_data = ImageDataset(X, y, transform=transforms.ToTensor()) + self._train_data, self._val_data = self._split_train_val_data(train_data) + X, y = self._get_filepaths_and_labels("test") + self._test_data = ImageDataset(X, y, transform=transforms.ToTensor()) + + def _get_filepaths_and_labels(self, split: str) -> Tuple[List[str], List[int]]: + """Extracts all the filepaths and labels for a given split.""" + path = os.path.join(self._data_path, self.data_id) + df = pd.read_csv( + os.path.join(path, f"{self.data_id}_{split}.txt"), + sep=" ", + header=None, + names=["path", "label"], + ) + data = list(df.path.apply(lambda x: os.path.join(path, x))) + labels = list(df.label) + return data, labels diff --git a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py index 85f0857d..0e9056d7 100644 --- a/src/renate/benchmark/datasets/wild_time_data.py +++ b/src/renate/benchmark/datasets/wild_time_data.py @@ -6,12 +6,12 @@ from transformers import PreTrainedTokenizer from renate import defaults -from renate.data.data_module import RenateDataModule +from renate.benchmark.datasets.base import DataIncrementalDataModule from renate.utils.file import download_folder_from_s3 from renate.utils.hf_utils import DataCollatorWithPaddingForWildTime -class WildTimeDataModule(RenateDataModule): +class WildTimeDataModule(DataIncrementalDataModule): """Data module wrapping around the Wild-Time data. 
Huaxiu Yao, Caroline Choi, Bochuan Cao, Yoonho Lee, Pang Wei Koh, Chelsea Finn: @@ -46,13 +46,13 @@ def __init__( ): super().__init__( data_path=data_path, + data_id=time_step, src_bucket=src_bucket, src_object_name=src_object_name, val_size=val_size, seed=seed, ) self._dataset_name = dataset_name - self.time_step = time_step self._tokenizer = tokenizer self._tokenizer_kwargs = tokenizer_kwargs @@ -87,7 +87,7 @@ def setup(self) -> None: kwargs = { "dataset_name": self._dataset_name, - "time_step": available_time_steps(self._dataset_name)[self.time_step], + "time_step": available_time_steps(self._dataset_name)[self.data_id], "data_dir": self._data_path, "in_memory": self._dataset_name != "fmow", } diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index a09d0b7f..74f5984f 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -6,13 +6,17 @@ import torch import wild_time_data from torch.optim import Optimizer -from torch.optim.lr_scheduler import StepLR, _LRScheduler +from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, _LRScheduler from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import transforms from transformers import AutoTokenizer from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule -from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule +from renate.benchmark.datasets.vision_datasets import ( + CLEARDataModule, + DomainNetDataModule, + TorchVisionDataModule, +) from renate.benchmark.datasets.wild_time_data import WildTimeDataModule from renate.benchmark.models import ( MultiLayerPerceptron, @@ -32,13 +36,13 @@ from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer from renate.benchmark.scenarios import ( ClassIncrementalScenario, + DataIncrementalScenario, FeatureSortingScenario, HueShiftScenario, IIDScenario, ImageRotationScenario, PermutationScenario, Scenario, - TimeIncrementalScenario, ) from renate.data.data_module import RenateDataModule from renate.models import RenateModule @@ -131,7 +135,14 @@ def get_data_module( if pretrained_model_name is not None: data_module_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(pretrained_model_name) return WildTimeDataModule(**data_module_kwargs) - + if dataset_name == "DomainNet": + return DomainNetDataModule( + data_path=data_path, + src_bucket=src_bucket, + src_object_name=src_object_name, + val_size=val_size, + seed=seed, + ) if dataset_name.startswith("hfd-"): tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) return HuggingFaceTextDataModule( @@ -157,6 +168,7 @@ def get_scenario( input_dim: Optional[Union[List[int], Tuple[int], int]] = None, feature_idx: Optional[int] = None, randomness: Optional[float] = None, + data_ids: Optional[List[Union[int, str]]] = None, ) -> Scenario: """Function to create scenario based on name and arguments. @@ -172,6 +184,7 @@ def get_scenario( input_dim: Used for scenario `PermutationScenario`. Input dimensionality. feature_idx: Used for scenario `SoftSortingScenario`. Index of feature to sort by. randomness: Used for all `_SortingScenario`. Randomness strength in [0, 1]. + data_ids: List of data_ids for the `DataIncrementalScenario`. Returns: An instance of the requested scenario. 
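The new ``DataIncrementalScenario`` pairs with any ``DataIncrementalDataModule``, such as the ``DomainNetDataModule`` introduced above. The following is a minimal sketch of that pairing; the data path and the domain subset are placeholders.

.. code-block:: python

    from renate.benchmark.datasets.vision_datasets import DomainNetDataModule
    from renate.benchmark.scenarios import DataIncrementalScenario

    # Placeholder path and domain subset. chunk_id selects which entry of data_ids
    # is used for training; test data is loaded for all listed domains.
    data_module = DomainNetDataModule(data_path="./data", domain="clipart", val_size=0.3)
    scenario = DataIncrementalScenario(
        data_module=data_module,
        data_ids=["clipart", "infograph", "painting"],
        chunk_id=1,  # train on "infograph"
    )
    scenario.prepare_data()  # downloads every domain listed in data_ids
    scenario.setup()         # loads train/val data for chunk 1 and test data for all chunks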
@@ -221,9 +234,11 @@ def get_scenario( chunk_id=chunk_id, seed=seed, ) - if scenario_name == "TimeIncrementalScenario": - return TimeIncrementalScenario( - data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed + if scenario_name == "DataIncrementalScenario": + if data_ids is None: + data_ids = [data_id for data_id in range(num_tasks)] + return DataIncrementalScenario( + data_module=data_module, data_ids=data_ids, chunk_id=chunk_id, seed=seed ) raise ValueError(f"Unknown scenario `{scenario_name}`.") @@ -252,6 +267,7 @@ def data_module_fn( pretrained_model_name: Optional[str] = None, input_column: Optional[str] = None, target_column: Optional[str] = None, + data_ids: Optional[List[Union[int, str]]] = None, ): data_module = get_data_module( data_path=data_path, @@ -277,6 +293,7 @@ def data_module_fn( input_dim=input_dim, feature_idx=feature_idx, randomness=randomness, + data_ids=data_ids, ) @@ -291,6 +308,11 @@ def _get_normalize_transform(dataset_name): CLEARDataModule.dataset_stats[dataset_name]["mean"], CLEARDataModule.dataset_stats[dataset_name]["std"], ) + if dataset_name == "DomainNet": + return transforms.Normalize( + DomainNetDataModule.dataset_stats["all"]["mean"], + DomainNetDataModule.dataset_stats["all"]["std"], + ) def train_transform(dataset_name: str) -> Optional[Callable]: @@ -308,7 +330,7 @@ def train_transform(dataset_name: str) -> Optional[Callable]: _get_normalize_transform(dataset_name), ] ) - if dataset_name in ["CLEAR10", "CLEAR100"]: + if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ transforms.Resize( @@ -330,7 +352,7 @@ def test_transform(dataset_name: str) -> Optional[Callable]: return None if dataset_name in ["CIFAR10", "CIFAR100"]: return _get_normalize_transform(dataset_name) - if dataset_name in ["CLEAR10", "CLEAR100"]: + if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ transforms.Resize( @@ -348,6 +370,8 @@ def lr_scheduler_fn( learning_rate_scheduler_step_size: int = 30, learning_rate_scheduler_gamma: float = 0.1, learning_rate_scheduler_interval: str = "epoch", + learning_rate_scheduler_t_max: Optional[int] = None, + learning_rate_scheduler_eta_min: float = 0, ) -> Tuple[Optional[Callable[[Optimizer], _LRScheduler]], str]: if learning_rate_scheduler == "StepLR": return ( @@ -358,6 +382,15 @@ def lr_scheduler_fn( ), learning_rate_scheduler_interval, ) + elif learning_rate_scheduler == "CosineAnnealingLR": + return ( + partial( + CosineAnnealingLR, + T_max=learning_rate_scheduler_t_max, + eta_min=learning_rate_scheduler_eta_min, + ), + learning_rate_scheduler_interval, + ) elif learning_rate_scheduler is None: return None, learning_rate_scheduler_interval raise ValueError(f"Unknown scheduler `{learning_rate_scheduler}`.") diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index d0384e53..70725423 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -9,8 +9,7 @@ from torchvision.transforms import Lambda, RandomRotation, ToPILImage from renate import defaults -from renate.benchmark.datasets.vision_datasets import CLEARDataModule -from renate.benchmark.datasets.wild_time_data import WildTimeDataModule +from renate.benchmark.datasets.base import DataIncrementalDataModule from renate.data.data_module import RenateDataModule from renate.data.datasets import _TransformedDataset from renate.utils.pytorch import get_generator, randomly_split_data @@ -388,17 +387,17 @@ def _get_scores(self, dataset: Dataset) 
-> List[float]: return scores -class TimeIncrementalScenario(Scenario): - """Creating a time-incremental scenario for specific datasets. +class DataIncrementalScenario(Scenario): + """Creating a scenario which iterates over pre-defined datasets. - Supports the Wild Time datasets and CLEAR. - DataModules that want to use the TimeIncrementalScenario, need to have an attribute - ``time_step``. Setting this variable and then calling ``setup()`` should load the time-specific - datasets. + The scenario will iterate over a list of datasets that are provided by the given ``DataModule``. + The data is loaded by assigning ``data_ids[chunk_id]`` to the attribute of the ``DataModule`` + with name ``domain`` and then calling its ``setup()`` function. Args: - data_module: The source RenateDataModule for the user data. - num_tasks: The total number of expected tasks for experimentation. + data_module: The source :py:class:`~renate.data.data_module.RenateDataModule` for the user + data. + data_ids: Unique identifier for each pre-defined dataset. chunk_id: The data chunk to load in for the training or validation data. seed: Seed used to fix random number generation. """ @@ -406,24 +405,34 @@ class TimeIncrementalScenario(Scenario): def __init__( self, data_module: RenateDataModule, - num_tasks: int, + data_ids: List[Union[int, str]], chunk_id: int, seed: int = defaults.SEED, ) -> None: - super().__init__(data_module=data_module, num_tasks=num_tasks, chunk_id=chunk_id, seed=seed) - if not isinstance(data_module, (CLEARDataModule, WildTimeDataModule)): + super().__init__( + data_module=data_module, num_tasks=len(data_ids), chunk_id=chunk_id, seed=seed + ) + if not isinstance(data_module, DataIncrementalDataModule): raise ValueError( - "This scenario is only compatible with `CLEARDataModule` and `WildTimeDataModule`." + "This scenario is only compatible with classes that extend " + "`DataIncrementalDataModule`." 
) + self._data_ids = data_ids + + def prepare_data(self) -> None: + """Downloads datasets.""" + for data_id in self._data_ids: + self._data_module.data_id = data_id + self._data_module.prepare_data() def setup(self) -> None: """Sets up the scenario.""" - self._data_module.time_step = self._chunk_id + self._data_module.data_id = self._data_ids[self._chunk_id] super().setup() self._train_data = self._data_module.train_data() self._val_data = self._data_module.val_data() self._test_data = [] - for i in range(self._num_tasks): - self._data_module.time_step = i + for data_id in self._data_ids: + self._data_module.data_id = data_id self._data_module.setup() self._test_data.append(self._data_module.test_data()) diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index ca0111b2..659fb042 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -3,7 +3,7 @@ import pytest from torch.nn import Linear from torch.optim import SGD -from torch.optim.lr_scheduler import StepLR +from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import Compose, Normalize @@ -23,12 +23,12 @@ ) from renate.benchmark.scenarios import ( ClassIncrementalScenario, + DataIncrementalScenario, FeatureSortingScenario, HueShiftScenario, IIDScenario, ImageRotationScenario, PermutationScenario, - TimeIncrementalScenario, ) from renate.models.prediction_strategies import ICaRLClassificationStrategy @@ -200,21 +200,28 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): HueShiftScenario, 3, ), - ("TimeIncrementalScenario", "CLEAR10", {"num_tasks": 5}, TimeIncrementalScenario, 5), + ("DataIncrementalScenario", "CLEAR10", {"num_tasks": 5}, DataIncrementalScenario, 5), ( - "TimeIncrementalScenario", + "DataIncrementalScenario", "arxiv", {"num_tasks": 3, "pretrained_model_name": "distilbert-base-uncased"}, - TimeIncrementalScenario, + DataIncrementalScenario, 3, ), ( - "TimeIncrementalScenario", + "DataIncrementalScenario", "fmow", {}, - TimeIncrementalScenario, + DataIncrementalScenario, 16, ), + ( + "DataIncrementalScenario", + "DomainNet", + {"data_ids": ["clipart", "infograph"]}, + DataIncrementalScenario, + 2, + ), ), ids=[ "class_incremental", @@ -226,6 +233,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "time_with_clear", "wild_time_text_with_tokenizer", "wild_time_image_all_tasks", + "domainnet", ], ) @pytest.mark.parametrize("val_size", (0, 0.5), ids=["no_val", "val"]) @@ -255,10 +263,10 @@ def test_data_module_fn( assert scenario._randomness == scenario_kwargs["randomness"] elif expected_scenario_class == HueShiftScenario: assert scenario._randomness == scenario_kwargs["randomness"] - elif expected_scenario_class == TimeIncrementalScenario: + elif expected_scenario_class == DataIncrementalScenario: if "pretrained_model_name" in scenario_kwargs: assert scenario._data_module._tokenizer is not None - elif dataset_name not in ["CLEAR10", "CLEAR100"]: + elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet"]: assert scenario._data_module._tokenizer is None assert scenario._num_tasks == expected_num_tasks @@ -271,6 +279,7 @@ def test_data_module_fn( ("CIFAR10", True, False), ("CIFAR100", True, False), ("CLEAR10", True, True), + ("DomainNet", True, True), ("hfd-rotten_tomatoes", False, False), ), ) @@ -297,10 +306,17 @@ def test_transforms_fails_for_unknown_dataset(): 
@pytest.mark.parametrize( "learning_rate_scheduler,expected_lr_class,expected_interval", - (("StepLR", StepLR, "epoch"), (None, None, "epoch")), + ( + ("StepLR", StepLR, "epoch"), + ("CosineAnnealingLR", CosineAnnealingLR, "step"), + (None, None, "epoch"), + ), ) def test_lr_scheduler_fn(learning_rate_scheduler, expected_lr_class, expected_interval): - scheduler, interval = lr_scheduler_fn(learning_rate_scheduler) + scheduler, interval = lr_scheduler_fn( + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=expected_interval, + ) assert interval == expected_interval if learning_rate_scheduler is None: assert scheduler is None diff --git a/test/renate/benchmark/test_scenarios.py b/test/renate/benchmark/test_scenarios.py index a53035cc..60950282 100644 --- a/test/renate/benchmark/test_scenarios.py +++ b/test/renate/benchmark/test_scenarios.py @@ -11,11 +11,11 @@ from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule from renate.benchmark.scenarios import ( ClassIncrementalScenario, + DataIncrementalScenario, FeatureSortingScenario, IIDScenario, ImageRotationScenario, PermutationScenario, - TimeIncrementalScenario, ) from renate.utils.pytorch import randomly_split_data @@ -79,12 +79,12 @@ def test_class_incremental_scenario_class_grouping_error(): scenario.setup() -def test_time_incremental_scenario_init_error(): - """Check that TimeIncrementalScenario raises Exception for unsupported DataModule.""" +def test_data_incremental_scenario_init_error(): + """Check that DataIncrementalScenario raises Exception for unsupported DataModule.""" with pytest.raises(ValueError, match=r"This scenario is only compatible with*"): - TimeIncrementalScenario( + DataIncrementalScenario( data_module=DummyTorchVisionDataModule(), - num_tasks=2, + data_ids=[0, 1], chunk_id=0, ) From df054f975e4d3d0956250035dbaf8e840064dabf Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Thu, 3 Aug 2023 16:50:10 +0200 Subject: [PATCH 58/89] Adding flags to expose gradient clipping args in Trainer (#361) --- benchmarks/run_benchmark.py | 6 +----- .../class_incremental_learning_cifar10_der.py | 1 + src/renate/benchmark/experimentation.py | 10 ++++++++++ src/renate/cli/parsing_functions.py | 13 ++++++++++++ src/renate/cli/run_training.py | 2 ++ src/renate/defaults.py | 2 ++ src/renate/training/training.py | 10 ++++++++++ .../updaters/avalanche/model_updater.py | 16 +++++++++++++++ src/renate/updaters/experimental/er.py | 20 +++++++++++++++++++ .../updaters/experimental/fine_tuning.py | 4 ++++ src/renate/updaters/experimental/gdumb.py | 4 ++++ src/renate/updaters/experimental/joint.py | 4 ++++ .../updaters/experimental/offline_er.py | 4 ++++ src/renate/updaters/model_updater.py | 9 +++++++++ 14 files changed, 100 insertions(+), 5 deletions(-) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 6f56e0b1..deb80e8f 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -89,11 +89,7 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): for seed in range(args.num_repetitions): if args.backend == "local": experiment_outputs_url = ( - Path("tmp") - / "renate-integration-tests" - / args.test_suite - / args.job_name - / str(seed) + Path("tmp") / "renate-integration-tests" / args.job_name / str(seed) ) role = None working_directory = str(Path("tmp") / "renate_working_dir") diff --git a/examples/benchmarking/class_incremental_learning_cifar10_der.py b/examples/benchmarking/class_incremental_learning_cifar10_der.py index 
fc8938cc..74c182a7 100644 --- a/examples/benchmarking/class_incremental_learning_cifar10_der.py +++ b/examples/benchmarking/class_incremental_learning_cifar10_der.py @@ -22,6 +22,7 @@ "dataset_name": "CIFAR10", "val_size": 0, "class_groupings": ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9)), + "num_outputs": 10, } for seed in range(10): diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index b4a0d72b..31722498 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -146,6 +146,8 @@ def execute_experiment_job( accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, devices: int = defaults.DEVICES, deterministic_trainer: bool = True, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, job_name: str = defaults.JOB_NAME, strategy: str = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, @@ -216,6 +218,8 @@ def execute_experiment_job( strategy=strategy, precision=precision, save_state=save_state, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) _execute_experiment_job_remotely( job_name=job_name, @@ -235,6 +239,8 @@ def execute_experiment_job( accelerator=accelerator, devices=devices, deterministic_trainer=deterministic_trainer, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, seed=seed, requirements_file=requirements_file, role=role, @@ -267,6 +273,8 @@ def _execute_experiment_job_locally( strategy: str, precision: str, save_state: bool, + gradient_clip_val: Optional[float], + gradient_clip_algorithm: Optional[str], ) -> None: """Runs an experiment, combining hyperparameter tuning and model for multiple updates. @@ -359,6 +367,8 @@ def _execute_experiment_job_locally( precision=precision, strategy=strategy, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) move_to_uri(output_state_url, input_state_url) if save_state: diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 82d20612..995e7d22 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -311,6 +311,19 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: "argument_group": OPTIONAL_ARGS_GROUP, "true_type": bool, }, + "gradient_clip_val": { + "type": lambda x: None if x == "None" else x, + "default": defaults.GRADIENT_CLIP_VAL, + "help": "The value at which to clip gradients. 
None disables clipping.", + "argument_group": OPTIONAL_ARGS_GROUP, + }, + "gradient_clip_algorithm": { + "type": lambda x: None if x == "None" else x, + "default": defaults.GRADIENT_CLIP_ALGORITHM, + "help": "Gradient clipping algorithm to use.", + "choices": ["norm", "value", None], + "argument_group": OPTIONAL_ARGS_GROUP, + }, "prepare_data": { "type": str, "default": "True", diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index bb08f050..a2d18293 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -169,6 +169,8 @@ def run(self): devices=args.devices, precision=args.precision, strategy=args.strategy, + gradient_clip_algorithm=args.gradient_clip_algorithm, + gradient_clip_val=args.gradient_clip_val, early_stopping_enabled=args.early_stopping, deterministic_trainer=args.deterministic_trainer, loss_fn=loss_fn, diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 27eccd04..9a53861d 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -33,6 +33,8 @@ VOLUME_SIZE = 60 DISTRIBUTED_STRATEGY = "ddp" PRECISION = "32" +GRADIENT_CLIP_VAL = None +GRADIENT_CLIP_ALGORITHM = None LEARNER = "ER" INSTANCE_COUNT = 1 diff --git a/src/renate/training/training.py b/src/renate/training/training.py index b528fba7..89c0cbb8 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -93,6 +93,8 @@ def run_training_job( strategy: str = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, job_name: str = defaults.JOB_NAME, ) -> Optional[Tuner]: """Starts updating the model including hyperparameter optimization. @@ -179,6 +181,8 @@ def run_training_job( devices=devices, strategy=strategy, precision=precision, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, deterministic_trainer=deterministic_trainer, ) submit_remote_job( @@ -213,6 +217,8 @@ def run_training_job( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, job_name=job_name, ) @@ -527,6 +533,8 @@ def _execute_training_and_tuning_job_locally( deterministic_trainer: bool, strategy: str, precision: str, + gradient_clip_algorithm: Optional[str], + gradient_clip_val: Optional[float], ): """Executes the training job locally. 
@@ -547,6 +555,8 @@ def _execute_training_and_tuning_job_locally( config_space["strategy"] = strategy config_space["precision"] = precision config_space["deterministic_trainer"] = deterministic_trainer + config_space["gradient_clip_val"] = gradient_clip_val + config_space["gradient_clip_algorithm"] = gradient_clip_algorithm if input_state_url is not None: config_space["input_state_url"] = input_state_url diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index 6f2839a0..68438f61 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -274,6 +274,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "batch_size": batch_size, @@ -306,6 +308,8 @@ def __init__( devices=devices, strategy=strategy, precision=precision, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) @@ -338,6 +342,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "batch_size": batch_size, @@ -369,6 +375,8 @@ def __init__( devices=devices, strategy=strategy, precision=precision, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) @@ -402,6 +410,8 @@ def __init__( strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "batch_size": batch_size, @@ -434,6 +444,8 @@ def __init__( devices=devices, strategy=strategy, precision=precision, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) @@ -466,6 +478,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -497,4 +511,6 @@ def __init__( devices=devices, strategy=strategy, precision=precision, + gradient_clip_val=gradient_clip_val, + gradient_clip_algorithm=gradient_clip_algorithm, ) diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index a0d4e331..f9e6da71 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -552,6 +552,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -590,6 +592,8 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + 
gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) @@ -629,6 +633,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -668,6 +674,8 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) @@ -708,6 +716,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -748,6 +758,8 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) @@ -791,6 +803,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -834,6 +848,8 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) @@ -883,6 +899,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -932,4 +950,6 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) diff --git a/src/renate/updaters/experimental/fine_tuning.py b/src/renate/updaters/experimental/fine_tuning.py index d9139269..f31295dd 100644 --- a/src/renate/updaters/experimental/fine_tuning.py +++ b/src/renate/updaters/experimental/fine_tuning.py @@ -42,6 +42,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "batch_size": batch_size, @@ -73,4 +75,6 @@ def __init__( deterministic_trainer=deterministic_trainer, strategy=strategy, precision=precision, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index 5c953706..d0d69656 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -132,6 +132,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + 
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -166,4 +168,6 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index 6ab5e52e..cf907560 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -121,6 +121,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "batch_size": batch_size, @@ -151,4 +153,6 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 2666bd7c..17c3e0e1 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -167,6 +167,8 @@ def __init__( precision: str = defaults.PRECISION, seed: int = defaults.SEED, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): learner_kwargs = { "memory_size": memory_size, @@ -202,4 +204,6 @@ def __init__( strategy=strategy, precision=precision, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, ) diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index 8f4db296..c340abde 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -238,6 +238,9 @@ class ModelUpdater(abc.ABC): The value is passed to the trainer as described `here `_. + gradient_clip_val: Gradient clipping value used in PyTorch Lightning. Defaults to not + clipping by using a value of None. + gradient_clip_algorithm: Method to clip gradients (norm or value) used in PyTorch Lightning. 
""" def __init__( @@ -268,6 +271,8 @@ def __init__( strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, precision: str = defaults.PRECISION, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, ): self._learner_kwargs = learner_kwargs or {} self._learner_kwargs["loss_fn"] = loss_fn @@ -336,6 +341,8 @@ def __init__( self._logger = logger self._num_epochs_trained = 0 self._deterministic_trainer = deterministic_trainer + self._gradient_clip_algorithm = gradient_clip_algorithm + self._gradient_clip_val = gradient_clip_val @abc.abstractmethod def update( @@ -424,6 +431,8 @@ def _fit_learner( deterministic=self._deterministic_trainer, strategy=strategy, precision=self._precision, + gradient_clip_val=self._gradient_clip_val, + gradient_clip_algorithm=self._gradient_clip_algorithm, ) trainer.fit(learner) self._num_epochs_trained = trainer.current_epoch From 54c37ced0207ea71a8e894ed62a4a6ae82971ada Mon Sep 17 00:00:00 2001 From: Giovanni <52964960+610v4nn1@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:49:34 +0200 Subject: [PATCH 59/89] Add benchmark made of four text datasets (#354) --- src/renate/benchmark/datasets/nlp_datasets.py | 169 ++++++++++++++++++ src/renate/defaults.py | 2 + .../benchmark/datasets/test_multi_data_nlp.py | 65 +++++++ 3 files changed, 236 insertions(+) create mode 100644 test/renate/benchmark/datasets/test_multi_data_nlp.py diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py index 906cdaae..5abe1fc8 100644 --- a/src/renate/benchmark/datasets/nlp_datasets.py +++ b/src/renate/benchmark/datasets/nlp_datasets.py @@ -1,13 +1,16 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import functools import logging from typing import Any, Dict, Optional import datasets import torch import transformers +from datasets import load_dataset from renate import defaults +from renate.benchmark.datasets.base import DataIncrementalDataModule from renate.data.data_module import RenateDataModule @@ -134,3 +137,169 @@ def tokenize_fn(batch): self._val_data = _InputTargetWrapper(self._val_data, self._target_column) else: self._train_data, self._val_data = self._split_train_val_data(self._train_data) + + +class MultiTextDataModule(DataIncrementalDataModule): + """ + Inspired by the dataset used in "Episodic Memory in Lifelong Language Learning" + by d’Autume et al. this is a collection of five different datasets that we call domains: + AGNews, Yelp, DBPedia and Yahoo Answers. + + The output space if the union of the output space of all the domains. + The dataset has 33 classes: 4 from AGNews, 5 from Yelp, 14 from DBPedia, and 10 from Yahoo. + + The maximum allowed size for the training set is 115000 and for the test set is 7600. + Each domain will have the same fixed size. + + Args: + data_path: The path to the folder where the data files will be downloaded to. + tokenizer: Tokenizer to apply to the dataset. See https://huggingface.co/docs/tokenizers/ + for more information on tokenizers. + tokenizer_kwargs: Keyword arguments passed when calling the tokenizer's ``__call__`` + function. Typical options are `max_length`, `padding` and `truncation`. + See https://huggingface.co/docs/tokenizers/ + for more information on tokenizers. 
If `None` is passed, this defaults to + `{"padding": "max_length", max_length: 128, truncation: True}`. + data_id: The dataset to be used + train_size: The size of the data stored as training set, must be smaller than 115000. + test_size: The size of the data stored as test set, must be smaller than 7600. + val_size: Fraction of the training data to be used for validation. + seed: Seed used to fix random number generation. + """ + + _multi_dataset_info = { + "ag_news": ["text", "label"], + "yelp_review_full": ["text", "label"], + "dbpedia_14": ["content", "label"], + "yahoo_answers_topics": ["question_title", "topic"], + } + _labels_map = { + "ag_news0": 0, + "ag_news1": 1, + "ag_news2": 2, + "ag_news3": 3, + "yelp_review_full0": 4, + "yelp_review_full1": 5, + "yelp_review_full2": 6, + "yelp_review_full3": 7, + "yelp_review_full4": 8, + "dbpedia_140": 9, + "dbpedia_141": 10, + "dbpedia_142": 11, + "dbpedia_143": 12, + "dbpedia_144": 13, + "dbpedia_145": 14, + "dbpedia_146": 15, + "dbpedia_147": 16, + "dbpedia_148": 17, + "dbpedia_149": 18, + "dbpedia_1410": 19, + "dbpedia_1411": 20, + "dbpedia_1412": 21, + "dbpedia_1413": 22, + "yahoo_answers_topics0": 23, + "yahoo_answers_topics1": 24, + "yahoo_answers_topics2": 25, + "yahoo_answers_topics3": 26, + "yahoo_answers_topics4": 27, + "yahoo_answers_topics5": 28, + "yahoo_answers_topics6": 29, + "yahoo_answers_topics7": 30, + "yahoo_answers_topics8": 31, + "yahoo_answers_topics9": 32, + } + + domains = _multi_dataset_info.keys() + + def __init__( + self, + data_path: str, + tokenizer: transformers.PreTrainedTokenizer, + data_id: str, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + train_size: int = defaults.SMALL_TRAIN_SET_SIZE, + test_size: int = defaults.SMALL_TEST_SET_SIZE, + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__(data_path=data_path, data_id=data_id, val_size=val_size, seed=seed) + + if train_size > 115000: + raise ValueError("The `train_size` must be smaller than or equal to 115000") + self._train_size = train_size + + if test_size > 7600: + raise ValueError("The `test_size` must be smaller than 7600") + self._test_size = test_size + + self._tokenizer = tokenizer + self._tokenizer_kwargs = tokenizer_kwargs or defaults.TOKENIZER_KWARGS + + if data_id not in self.domains: + raise ValueError( + f"The selected domain is not available. 
Select one among " f"{self.domains}" + ) + + self.data_id = data_id + + def prepare_data(self) -> None: + """Download dataset.""" + + for split in ["train", "test"] + (["validation"] if self._val_size > 0 else []): + load_dataset(self.data_id, split=split, cache_dir=self._data_path) + + def setup(self) -> None: + """Set up train, test and val datasets.""" + + rnd_gen = torch.Generator().manual_seed(self._seed) + + def preprocess(example, text_field_name, label_field_name): + return { + **self._tokenizer(example[text_field_name], **self._tokenizer_kwargs), + "label": self._labels_map[f"{self.data_id}{example[label_field_name]}"], + } + + def get_split(split_name): + dataset = load_dataset(self.data_id, split=split_name, cache_dir=self._data_path) + # the following is hack needed because the output space of the new dataset is + # the union of the output spaces of the single datasets + # HF datasets check for the max label id and we need to make sure we update that + # without this change the setup will fail with a value error (label id > max labels) + new_features = dataset.features.copy() + new_features[self._multi_dataset_info[self.data_id][1]] = datasets.ClassLabel( + num_classes=33 + ) + + dataset = dataset.cast(new_features) + + if "train" == split_name: + set_size = self._train_size + else: + set_size = self._test_size + + rnd_idx = torch.randint( + low=0, + high=len(dataset), + size=(set_size,), + generator=rnd_gen, + ).tolist() + dataset = dataset.select(indices=rnd_idx) + + dataset = dataset.map( + functools.partial( + preprocess, + text_field_name=self._multi_dataset_info[self.data_id][0], + label_field_name=self._multi_dataset_info[self.data_id][1], + ), + remove_columns=list(dataset.features), + num_proc=4, + ) + + dataset.set_format(type="torch") + + return _InputTargetWrapper(dataset) + + self._train_data = get_split("train") + self._test_data = get_split("test") + if self._val_size > 0: + self._val_data = get_split("validation") diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 9a53861d..c41d7499 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -106,6 +106,8 @@ # Benchmark datasets/models TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True} +SMALL_TRAIN_SET_SIZE = 1000 +SMALL_TEST_SET_SIZE = 1000 def scheduler(config_space: Dict[str, Any], mode: str, metric: str): diff --git a/test/renate/benchmark/datasets/test_multi_data_nlp.py b/test/renate/benchmark/datasets/test_multi_data_nlp.py new file mode 100644 index 00000000..f91765c2 --- /dev/null +++ b/test/renate/benchmark/datasets/test_multi_data_nlp.py @@ -0,0 +1,65 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch +import transformers as transformers + +from renate.benchmark.datasets.nlp_datasets import MultiTextDataModule + + +@pytest.mark.skip(reason="This test create problems with the syne-tune redirect test") +def test_multi_data_nlp_small(tmpdir): + TRAIN_SIZE = 100 + TEST_SIZE = 100 + + data = MultiTextDataModule( + tmpdir, + train_size=TRAIN_SIZE, + test_size=TEST_SIZE, + data_id="ag_news", + tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"), + seed=42, + ) + + data.prepare_data() + data.setup() + + assert len(data.train_data()) == TRAIN_SIZE + assert len(data.test_data()) == TEST_SIZE + + first_input_agnews = data.train_data()[0][0]["input_ids"] + + data.data_id = "dbpedia_14" + data.setup() + + tr_data_dbpedia = data.train_data() + te_data_dbpedia = data.test_data() + assert len(tr_data_dbpedia) == TRAIN_SIZE + assert len(te_data_dbpedia) == TEST_SIZE + + first_input_dbpedia = data.train_data()[0][0]["input_ids"] + + assert not torch.all(torch.eq(first_input_dbpedia, first_input_agnews)) + + +@pytest.mark.skip(reason="This test requires downloading and processing four datasets.") +def test_multi_data_nlp_full(tmpdir): + TRAIN_SIZE = 115000 + TEST_SIZE = 7600 + + for d in MultiTextDataModule.domains: + data = MultiTextDataModule( + tmpdir, + train_size=TRAIN_SIZE, + test_size=TEST_SIZE, + data_id=d, + tokenizer=transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased"), + ) + + data.prepare_data() + data.setup() + + tr_data = data.train_data() + te_data = data.test_data() + assert len(tr_data) == TRAIN_SIZE + assert len(te_data) == TEST_SIZE From 445ebbb1d4370a3d9fd141d5b68b94d525b4a69d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:57:40 +0200 Subject: [PATCH 60/89] Bump sphinx from 6.1.3 to 7.1.0 (#352) --- doc/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 18dc8184..77449f40 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,5 +1,5 @@ docutils==0.19 -Sphinx==6.1.3 +Sphinx==7.1.0 sphinx-copybutton==0.5.2 sphinx-hoverxref==1.3.0 sphinxext-opengraph==0.8.2 From 1538e30dde08028d2dd8508f08491383dd507fd4 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:59:08 +0200 Subject: [PATCH 61/89] Update scipy requirement from <1.10.2,>=1.9.0 to >=1.9.0,<1.11.2 (#324) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 4c011b0e..a969afd9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 datasets>=2.9.0, <2.12.1 transformers>=4.30.0, <4.30.2 -scipy>=1.9.0, <1.10.2 +scipy>=1.9.0, <1.11.2 From 448fdb4d855201bdfb6df427ced01ce18f4b3f78 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 3 Aug 2023 18:59:56 +0200 Subject: [PATCH 62/89] Bump pytest from 7.3.1 to 7.4.0 (#321) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fd7864ca..249b7c43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dev = [ "wild-time-data==0.1.1", "torch>=1.10.0, <1.12.2", # PyTest Dependencies - "pytest==7.3.1", + "pytest==7.4.0", "pytest-cov==4.1.0", 
"pytest-helpers-namespace==2021.12.29", ] From e456d243e628969a9232f7126ec6511824854126 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Aug 2023 11:09:01 +0200 Subject: [PATCH 63/89] Update datasets requirement from <2.12.1,>=2.9.0 to >=2.9.0,<2.14.1 (#348) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a969afd9..c6e7631b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,6 @@ tensorboardX>=2.5.0, <2.5.2 torchmetrics>=0.11.0, <0.11.5 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 -datasets>=2.9.0, <2.12.1 +datasets>=2.9.0, <2.14.1 transformers>=4.30.0, <4.30.2 scipy>=1.9.0, <1.11.2 From 702d01aad3cebe0e8599b20ef31075d9344f98c3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Aug 2023 11:12:12 +0200 Subject: [PATCH 64/89] Update tensorboardx requirement from <2.5.2,>=2.5.0 to >=2.5.0,<2.6.2 (#308) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c6e7631b..24cddbb3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ syne-tune[aws,gpsearchers]==0.6.0 pytorch-lightning>=1.8.0, <1.9.5 Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 -tensorboardX>=2.5.0, <2.5.2 +tensorboardX>=2.5.0, <2.6.2 torchmetrics>=0.11.0, <0.11.5 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 From d12c3c1f4897416e113163bf08921aa56bed1082 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:51:59 +0200 Subject: [PATCH 65/89] Update numpy requirement from <1.24.4,>=1.17.2 to >=1.17.2,<1.25.2 (#335) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 24cddbb3..e5dc6b94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.17.2, <1.24.4 +numpy>=1.17.2, <1.25.2 torch>=1.10.0, <1.13.2 pandas>=1.4.0, <2.0.3 boto3>=1.26.0, <1.26.139 From c0612c0b44ca9cb75034e55013f0398d01005e89 Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 4 Aug 2023 19:19:18 +0200 Subject: [PATCH 66/89] Support Use of Joint and GDumb with Pre-Trained Models (#362) --- src/renate/cli/parsing_functions.py | 13 +++++- src/renate/cli/run_training.py | 2 +- src/renate/updaters/experimental/gdumb.py | 2 - src/renate/updaters/experimental/joint.py | 2 - .../configs/suites/quick/gdumb.json | 4 +- .../configs/suites/quick/joint.json | 4 +- .../updaters/experimental/test_joint.py | 42 ++++--------------- 7 files changed, 25 insertions(+), 44 deletions(-) diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 995e7d22..cb79af5c 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -454,11 +454,22 @@ def _add_base_experience_replay_arguments(arguments: Dict[str, Dict[str, Any]]) def _add_gdumb_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: """A helper function that adds GDumb arguments.""" _add_replay_learner_arguments(arguments) + _add_joint_arguments(arguments) def _add_joint_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: """A helper function that adds Joint Learner arguments.""" - pass + arguments.update( + { + "reset": { + "type": str, + "default": "True", + "choices": ["True", "False"], + "help": "Resets the model before the update. 
Default: True", + "true_type": bool, + }, + } + ) def _add_finetuning_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: diff --git a/src/renate/cli/run_training.py b/src/renate/cli/run_training.py index a2d18293..6fda7d17 100644 --- a/src/renate/cli/run_training.py +++ b/src/renate/cli/run_training.py @@ -122,7 +122,7 @@ def run(self): ) model = get_model( config_module, - model_state_url=self._current_model_file, + model_state_url=None if getattr(args, "reset", False) else self._current_model_file, **get_function_kwargs(args=args, function_args=function_args["model_fn"]), ) loss_fn = get_loss_fn( diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index d0d69656..b7cee0b6 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -17,7 +17,6 @@ from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater -from renate.utils.pytorch import reinitialize_model_parameters class GDumbLearner(ReplayLearner): @@ -79,7 +78,6 @@ def on_model_update_start( task_id=task_id, ) self._memory_buffer.update(train_dataset) - reinitialize_model_parameters(self._model) def train_dataloader(self) -> DataLoader: return DataLoader( diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index cf907560..22d38bd6 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -18,7 +18,6 @@ from renate.types import NestedTensors from renate.updaters.learner import Learner from renate.updaters.model_updater import SingleTrainingLoopUpdater -from renate.utils.pytorch import reinitialize_model_parameters class JointLearner(Learner): @@ -72,7 +71,6 @@ def on_model_update_start( task_id=task_id, ) self._memory_buffer.update(train_dataset) - reinitialize_model_parameters(self._model) def train_dataloader(self) -> DataLoader: return DataLoader( diff --git a/test/integration_tests/configs/suites/quick/gdumb.json b/test/integration_tests/configs/suites/quick/gdumb.json index 04631cd6..abb8e597 100644 --- a/test/integration_tests/configs/suites/quick/gdumb.json +++ b/test/integration_tests/configs/suites/quick/gdumb.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-gdumb", - "expected_accuracy_linux": [[0.8364999890327454, 0.8634999990463257]], - "expected_accuracy_darwin": [[0.8364999890327454, 0.8634999990463257]] + "expected_accuracy_linux": [[0.43050000071525574, 0.8069999814033508]], + "expected_accuracy_darwin": [[0.43050000071525574, 0.8069999814033508]] } diff --git a/test/integration_tests/configs/suites/quick/joint.json b/test/integration_tests/configs/suites/quick/joint.json index 2de798c4..f6decc97 100644 --- a/test/integration_tests/configs/suites/quick/joint.json +++ b/test/integration_tests/configs/suites/quick/joint.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "iid-mlp-joint", - "expected_accuracy_linux": [[0.8495000004768372, 0.8495000004768372], [0.8427000045776367, 0.8427000045776367]], - "expected_accuracy_darwin": [[0.84170001745224, 0.84170001745224]] + "expected_accuracy_linux": [[0.8496000170707703, 0.8496000170707703], [0.8550000190734863, 0.8550000190734863]], + "expected_accuracy_darwin": [[0.8585000038146973, 0.8585000038146973]] } diff --git a/test/renate/updaters/experimental/test_joint.py 
b/test/renate/updaters/experimental/test_joint.py index fb0e62e0..a299abd5 100644 --- a/test/renate/updaters/experimental/test_joint.py +++ b/test/renate/updaters/experimental/test_joint.py @@ -1,28 +1,22 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import copy import pytest -import torch import renate.defaults as defaults from renate.updaters.experimental.joint import JointLearner -def get_model_and_dataset(): - model = pytest.helpers.get_renate_module_mlp( - num_outputs=10, num_inputs=10, hidden_size=32, num_hidden_layers=3 - ) - dataset = torch.utils.data.TensorDataset( - torch.rand((100, 10)), - torch.randint(10, (100,)), - ) - return model, dataset - - def test_joint_learner_memory_append(): """This test checks that the memory buffer is updated correctly.""" - model, dataset = get_model_and_dataset() + model, dataset, _ = pytest.helpers.get_renate_module_mlp_and_data( + num_inputs=10, + num_outputs=10, + num_hidden_layers=3, + hidden_size=32, + train_num_samples=100, + test_num_samples=100, + ) dataset_len = len(dataset) model_updater = pytest.helpers.get_simple_updater( model=model, @@ -35,23 +29,3 @@ def test_joint_learner_memory_append(): assert len(model_updater._learner._memory_buffer) == dataset_len model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) assert len(model_updater._learner._memory_buffer) == 2 * dataset_len - - -def test_joint_learner_model_reset(): - """This test checks that the model is reinitialized correctly before an update.""" - model, dataset = get_model_and_dataset() - model_updater = pytest.helpers.get_simple_updater( - model=model, - partial_optimizer=pytest.helpers.get_partial_optimizer(lr=0.0), - learner_class=JointLearner, - learner_kwargs={}, - max_epochs=1, - ) - model = model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) - model_copy = copy.deepcopy(model) - model = model_updater.update(train_dataset=dataset, task_id=defaults.TASK_ID) - for (name, param), (name_copy, param_copy) in zip( - model.named_parameters(), model_copy.named_parameters() - ): - assert name == name_copy - assert not torch.allclose(param, param_copy) From 194b62e7273f5d5a0b60a2e1516ea63fe228f280 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 7 Aug 2023 15:26:50 +0200 Subject: [PATCH 67/89] Update pandas requirement from <2.0.3,>=1.4.0 to >=1.4.0,<2.0.4 (#326) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e5dc6b94..a8121e0e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy>=1.17.2, <1.25.2 torch>=1.10.0, <1.13.2 -pandas>=1.4.0, <2.0.3 +pandas>=1.4.0, <2.0.4 boto3>=1.26.0, <1.26.139 requests>=2.31.0, <2.31.1 sagemaker>=2.112.0, <2.158.1 From 38ea8a7d49ec9c4f3e4caf2e35e15ff8bf0f1889 Mon Sep 17 00:00:00 2001 From: wistuba Date: Tue, 8 Aug 2023 09:00:03 +0200 Subject: [PATCH 68/89] Wild Time Benchmarks and Small Memory Hack (#363) --- .../experiment_configs/datasets/arxiv.json | 6 ++++++ .../experiment_configs/datasets/fmow.json | 6 ++---- .../experiment_configs/datasets/huffpost.json | 7 ++----- .../experiment_configs/fine-tuning-arxiv.json | 6 ++++++ .../experiment_configs/fine-tuning-fmow.json | 2 +- .../fine-tuning-huffpost.json | 2 +- .../experiment_configs/joint-arxiv.json | 6 ++++++ benchmarks/experiment_configs/joint-fmow.json | 6 ++++++ .../experiment_configs/joint-huffpost.json | 6 ++++++ 
.../scenarios/arxiv-16updates.json | 9 +++++++++ .../scenarios/clear10-10updates.json | 1 - .../scenarios/clear100-11updates.json | 1 - .../scenarios/fmow-16updates.json | 14 +++++++++++++ .../scenarios/huffpost-7updates.json | 9 +++++++++ .../scenarios/wild-time.json | 4 ---- .../updaters/fine-tuning-arxiv.json | 4 ++++ .../updaters/fine-tuning-fmow.json | 9 +-------- .../updaters/fine-tuning-huffpost.json | 6 +----- .../updaters/joint-arxiv.json | 4 ++++ .../updaters/joint-fmow.json | 4 ++++ .../updaters/joint-huffpost.json | 4 ++++ benchmarks/run_benchmark.py | 1 - .../benchmark/datasets/vision_datasets.py | 12 +++++------ .../benchmark/datasets/wild_time_data.py | 1 + src/renate/benchmark/experiment_config.py | 20 ++++++++++++++++++- .../benchmark/models/vision_transformer.py | 4 ++-- .../benchmark/test_experimentation_config.py | 18 ++++++++++++++++- 27 files changed, 131 insertions(+), 41 deletions(-) create mode 100644 benchmarks/experiment_configs/datasets/arxiv.json create mode 100644 benchmarks/experiment_configs/fine-tuning-arxiv.json create mode 100644 benchmarks/experiment_configs/joint-arxiv.json create mode 100644 benchmarks/experiment_configs/joint-fmow.json create mode 100644 benchmarks/experiment_configs/joint-huffpost.json create mode 100644 benchmarks/experiment_configs/scenarios/arxiv-16updates.json create mode 100644 benchmarks/experiment_configs/scenarios/fmow-16updates.json create mode 100644 benchmarks/experiment_configs/scenarios/huffpost-7updates.json delete mode 100644 benchmarks/experiment_configs/scenarios/wild-time.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-arxiv.json create mode 100644 benchmarks/experiment_configs/updaters/joint-arxiv.json create mode 100644 benchmarks/experiment_configs/updaters/joint-fmow.json create mode 100644 benchmarks/experiment_configs/updaters/joint-huffpost.json diff --git a/benchmarks/experiment_configs/datasets/arxiv.json b/benchmarks/experiment_configs/datasets/arxiv.json new file mode 100644 index 00000000..6af20941 --- /dev/null +++ b/benchmarks/experiment_configs/datasets/arxiv.json @@ -0,0 +1,6 @@ +{ + "dataset_name": "arxiv", + "src_bucket": "my_bucket", + "src_object_name": "dataset/wildtime/arxiv.hdf5", + "num_outputs": 172 +} diff --git a/benchmarks/experiment_configs/datasets/fmow.json b/benchmarks/experiment_configs/datasets/fmow.json index eb36763b..cac666fe 100644 --- a/benchmarks/experiment_configs/datasets/fmow.json +++ b/benchmarks/experiment_configs/datasets/fmow.json @@ -1,9 +1,7 @@ { "dataset_name": "fmow", - "src_bucket": "mnemosyne-team-bucket", + "src_bucket": "my_bucket", "src_object_name": "dataset/wildtime/fmow.hdf5", "num_inputs": 150528, - "num_outputs": 62, - "num_tasks": 16, - "max_epochs": 50 + "num_outputs": 62 } diff --git a/benchmarks/experiment_configs/datasets/huffpost.json b/benchmarks/experiment_configs/datasets/huffpost.json index af65f008..2b4fcf4f 100644 --- a/benchmarks/experiment_configs/datasets/huffpost.json +++ b/benchmarks/experiment_configs/datasets/huffpost.json @@ -1,9 +1,6 @@ { "dataset_name": "huffpost", - "src_bucket": "mnemosyne-team-bucket", + "src_bucket": "my_bucket", "src_object_name": "dataset/wildtime/huffpost.hdf5", - "num_inputs": 0, - "num_outputs": 11, - "num_tasks": 7, - "max_epochs": 5 + "num_outputs": 11 } diff --git a/benchmarks/experiment_configs/fine-tuning-arxiv.json b/benchmarks/experiment_configs/fine-tuning-arxiv.json new file mode 100644 index 00000000..54e12550 --- /dev/null +++ 
b/benchmarks/experiment_configs/fine-tuning-arxiv.json @@ -0,0 +1,6 @@ +{ + "scenario": "arxiv-16updates.json", + "model": "bert.json", + "updater": "fine-tuning-arxiv.json", + "dataset": "arxiv.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-fmow.json b/benchmarks/experiment_configs/fine-tuning-fmow.json index 060b25d4..7554f589 100644 --- a/benchmarks/experiment_configs/fine-tuning-fmow.json +++ b/benchmarks/experiment_configs/fine-tuning-fmow.json @@ -1,5 +1,5 @@ { - "scenario": "wild-time.json", + "scenario": "fmow-16updates.json", "model": "resnet18.json", "updater": "fine-tuning-fmow.json", "dataset": "fmow.json" diff --git a/benchmarks/experiment_configs/fine-tuning-huffpost.json b/benchmarks/experiment_configs/fine-tuning-huffpost.json index b3478c77..f5af8d10 100644 --- a/benchmarks/experiment_configs/fine-tuning-huffpost.json +++ b/benchmarks/experiment_configs/fine-tuning-huffpost.json @@ -1,5 +1,5 @@ { - "scenario": "wild-time.json", + "scenario": "huffpost-7updates.json", "model": "bert.json", "updater": "fine-tuning-huffpost.json", "dataset": "huffpost.json" diff --git a/benchmarks/experiment_configs/joint-arxiv.json b/benchmarks/experiment_configs/joint-arxiv.json new file mode 100644 index 00000000..6f0b64c3 --- /dev/null +++ b/benchmarks/experiment_configs/joint-arxiv.json @@ -0,0 +1,6 @@ +{ + "scenario": "arxiv-16updates.json", + "model": "bert.json", + "updater": "joint-arxiv.json", + "dataset": "arxiv.json" +} diff --git a/benchmarks/experiment_configs/joint-fmow.json b/benchmarks/experiment_configs/joint-fmow.json new file mode 100644 index 00000000..ee849987 --- /dev/null +++ b/benchmarks/experiment_configs/joint-fmow.json @@ -0,0 +1,6 @@ +{ + "scenario": "fmow-16updates.json", + "model": "resnet18.json", + "updater": "joint-fmow.json", + "dataset": "fmow.json" +} diff --git a/benchmarks/experiment_configs/joint-huffpost.json b/benchmarks/experiment_configs/joint-huffpost.json new file mode 100644 index 00000000..d05cd544 --- /dev/null +++ b/benchmarks/experiment_configs/joint-huffpost.json @@ -0,0 +1,6 @@ +{ + "scenario": "huffpost-7updates.json", + "model": "bert.json", + "updater": "joint-huffpost.json", + "dataset": "huffpost.json" +} diff --git a/benchmarks/experiment_configs/scenarios/arxiv-16updates.json b/benchmarks/experiment_configs/scenarios/arxiv-16updates.json new file mode 100644 index 00000000..67ceba52 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/arxiv-16updates.json @@ -0,0 +1,9 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 16, + "max_epochs": 2, + "optimizer": "AdamW", + "learning_rate": 0.00002, + "weight_decay": 0.01 +} diff --git a/benchmarks/experiment_configs/scenarios/clear10-10updates.json b/benchmarks/experiment_configs/scenarios/clear10-10updates.json index 46c52936..f22e0070 100644 --- a/benchmarks/experiment_configs/scenarios/clear10-10updates.json +++ b/benchmarks/experiment_configs/scenarios/clear10-10updates.json @@ -7,7 +7,6 @@ "learning_rate": 0.01, "momentum": 0.9, "weight_decay": 1e-5, - "batch_size": 256, "learning_rate_scheduler": "StepLR", "learning_rate_scheduler_step_size": 30, "learning_rate_scheduler_gamma": 0.1, diff --git a/benchmarks/experiment_configs/scenarios/clear100-11updates.json b/benchmarks/experiment_configs/scenarios/clear100-11updates.json index 8fafd4c4..2df13ae6 100644 --- a/benchmarks/experiment_configs/scenarios/clear100-11updates.json +++ b/benchmarks/experiment_configs/scenarios/clear100-11updates.json @@ -7,7 +7,6 @@ "learning_rate": 
0.01, "momentum": 0.9, "weight_decay": 1e-5, - "batch_size": 256, "learning_rate_scheduler": "StepLR", "learning_rate_scheduler_step_size": 30, "learning_rate_scheduler_gamma": 0.1, diff --git a/benchmarks/experiment_configs/scenarios/fmow-16updates.json b/benchmarks/experiment_configs/scenarios/fmow-16updates.json new file mode 100644 index 00000000..e9bbfabf --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/fmow-16updates.json @@ -0,0 +1,14 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 16, + "max_epochs": 50, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 50, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/scenarios/huffpost-7updates.json b/benchmarks/experiment_configs/scenarios/huffpost-7updates.json new file mode 100644 index 00000000..121782d6 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/huffpost-7updates.json @@ -0,0 +1,9 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 7, + "max_epochs": 4, + "optimizer": "AdamW", + "learning_rate": 0.00002, + "weight_decay": 0.01 +} diff --git a/benchmarks/experiment_configs/scenarios/wild-time.json b/benchmarks/experiment_configs/scenarios/wild-time.json deleted file mode 100644 index ec2b116b..00000000 --- a/benchmarks/experiment_configs/scenarios/wild-time.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "val_size": 0.1, - "scenario_name": "DataIncrementalScenario" -} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-arxiv.json b/benchmarks/experiment_configs/updaters/fine-tuning-arxiv.json new file mode 100644 index 00000000..09e7a594 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-arxiv.json @@ -0,0 +1,4 @@ +{ + "updater": "FineTuning", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json b/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json index 77040f12..09e7a594 100644 --- a/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json +++ b/benchmarks/experiment_configs/updaters/fine-tuning-fmow.json @@ -1,11 +1,4 @@ { "updater": "FineTuning", - "optimizer": "SGD", - "learning_rate": 0.03, - "learning_rate_scheduler": "CosineAnnealingLR", - "learning_rate_scheduler_t_max": 50, - "learning_rate_scheduler_eta_min": 0.0001, - "learning_rate_scheduler_interval": "step", - "momentum": 0.0, - "weight_decay": 0.0 + "batch_size": 64 } diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json b/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json index bdb46bc2..7184ed85 100644 --- a/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json +++ b/benchmarks/experiment_configs/updaters/fine-tuning-huffpost.json @@ -1,8 +1,4 @@ { "updater": "FineTuning", - "optimizer": "Adam", - "learning_rate": 0.0001, - "momentum": 0.9, - "weight_decay": 0.0, - "batch_size": 64 + "batch_size": 32 } diff --git a/benchmarks/experiment_configs/updaters/joint-arxiv.json b/benchmarks/experiment_configs/updaters/joint-arxiv.json new file mode 100644 index 00000000..0cfc3ad2 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/joint-arxiv.json @@ -0,0 +1,4 @@ +{ + "updater": "Joint", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/joint-fmow.json b/benchmarks/experiment_configs/updaters/joint-fmow.json new file 
mode 100644 index 00000000..0cfc3ad2 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/joint-fmow.json @@ -0,0 +1,4 @@ +{ + "updater": "Joint", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/joint-huffpost.json b/benchmarks/experiment_configs/updaters/joint-huffpost.json new file mode 100644 index 00000000..3ec09677 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/joint-huffpost.json @@ -0,0 +1,4 @@ +{ + "updater": "Joint", + "batch_size": 32 +} diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index deb80e8f..953e0819 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -91,7 +91,6 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): experiment_outputs_url = ( Path("tmp") / "renate-integration-tests" / args.job_name / str(seed) ) - role = None working_directory = str(Path("tmp") / "renate_working_dir") else: AWS_ACCOUNT_ID = boto3.client("sts").get_caller_identity().get("Account") diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 11d5c90b..abd33af6 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -68,10 +68,10 @@ def prepare_data(self) -> None: def setup(self) -> None: """Set up train, test and val datasets.""" X, y = self._preprocess_tiny_imagenet(train=True) - train_data = ImageDataset(X, y, transform=transforms.ToTensor()) + train_data = ImageDataset(X, y) self._train_data, self._val_data = self._split_train_val_data(train_data) X, y = self._preprocess_tiny_imagenet(train=False) - self._test_data = ImageDataset(X, y, transform=transforms.ToTensor()) + self._test_data = ImageDataset(X, y) def _preprocess_tiny_imagenet(self, train: bool) -> Tuple[List[str], List[int]]: """A helper function to preprocess the TinyImageNet dataset.""" @@ -259,10 +259,10 @@ def setup(self) -> None: """Set up train, test and val datasets.""" time_step = self.data_id + 1 if self._dataset_name == "clear10" else self.data_id X, y = self._get_filepaths_and_labels(train=True, time_step=time_step) - train_data = ImageDataset(X, y, transform=transforms.ToTensor()) + train_data = ImageDataset(X, y) self._train_data, self._val_data = self._split_train_val_data(train_data) X, y = self._get_filepaths_and_labels(train=False, time_step=time_step) - self._test_data = ImageDataset(X, y, transform=transforms.ToTensor()) + self._test_data = ImageDataset(X, y) def _get_filepaths_and_labels(self, train: bool, time_step: int) -> Tuple[List[str], List[int]]: """Extracts all the filepaths and labels for a given chunk id and split.""" @@ -384,10 +384,10 @@ def prepare_data(self) -> None: def setup(self) -> None: """Set up train, test and val datasets.""" X, y = self._get_filepaths_and_labels("train") - train_data = ImageDataset(X, y, transform=transforms.ToTensor()) + train_data = ImageDataset(X, y) self._train_data, self._val_data = self._split_train_val_data(train_data) X, y = self._get_filepaths_and_labels("test") - self._test_data = ImageDataset(X, y, transform=transforms.ToTensor()) + self._test_data = ImageDataset(X, y) def _get_filepaths_and_labels(self, split: str) -> Tuple[List[str], List[int]]: """Extracts all the filepaths and labels for a given split.""" diff --git a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py index 0e9056d7..b8260222 100644 --- a/src/renate/benchmark/datasets/wild_time_data.py +++ 
b/src/renate/benchmark/datasets/wild_time_data.py @@ -90,6 +90,7 @@ def setup(self) -> None: "time_step": available_time_steps(self._dataset_name)[self.data_id], "data_dir": self._data_path, "in_memory": self._dataset_name != "fmow", + "transform": None if self._dataset_name != "fmow" else lambda x: x, } if self._tokenizer: kwargs["transform"] = lambda x: self._tokenizer(x, **(self._tokenizer_kwargs or {})) diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 74f5984f..7516579c 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -5,11 +5,13 @@ import torch import wild_time_data -from torch.optim import Optimizer +from torch.optim import AdamW, Optimizer from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, _LRScheduler from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import transforms from transformers import AutoTokenizer +from wild_time_data import default_transform +from wild_time_data.datasets import FMoW from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.benchmark.datasets.vision_datasets import ( @@ -317,6 +319,8 @@ def _get_normalize_transform(dataset_name): def train_transform(dataset_name: str) -> Optional[Callable]: """Returns a transform function to be used in the training.""" + if dataset_name == "fmow": + return default_transform(dataset_name) if dataset_name in [ "MNIST", "FashionMNIST", @@ -333,6 +337,7 @@ def train_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ + transforms.ToTensor(), transforms.Resize( 224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True ), @@ -345,6 +350,8 @@ def train_transform(dataset_name: str) -> Optional[Callable]: def test_transform(dataset_name: str) -> Optional[Callable]: """Returns a transform function to be used for validation or testing.""" + if dataset_name == "fmow": + return FMoW.default_transform if dataset_name in [ "MNIST", "FashionMNIST", @@ -355,6 +362,7 @@ def test_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ + transforms.ToTensor(), transforms.Resize( 224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True ), @@ -398,3 +406,13 @@ def lr_scheduler_fn( def metrics_fn(num_outputs: int) -> Dict: return {"accuracy": MulticlassAccuracy(num_classes=num_outputs, average="micro")} + + +def optimizer_fn( + optimizer: str, + learning_rate: float, + weight_decay: float, + momentum: float = 0.0, # TODO: fix problem that occurs when removing this +) -> Callable: + if optimizer == "AdamW": + return partial(AdamW, lr=learning_rate, weight_decay=weight_decay) diff --git a/src/renate/benchmark/models/vision_transformer.py b/src/renate/benchmark/models/vision_transformer.py index 4ad4aa19..4d0e8dce 100644 --- a/src/renate/benchmark/models/vision_transformer.py +++ b/src/renate/benchmark/models/vision_transformer.py @@ -50,8 +50,8 @@ class VisionTransformer(RenateBenchmarkingModule): arXiv preprint arXiv:2010.11929 (2020). Args: - pretrained_name: A string that denotes which pretrained model from the HF hub to use. - If provided, it overrides other arguments about architecture. + pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub + to use. If provided, it overrides other arguments about architecture. 
image_size: Size of the input image. patch_size: Size of the patches. num_layers: Number of Encoder layers. diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 659fb042..512c40d0 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import pytest from torch.nn import Linear -from torch.optim import SGD +from torch.optim import AdamW, SGD from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR from torchmetrics.classification import MulticlassAccuracy from torchvision.transforms import Compose, Normalize @@ -19,6 +19,7 @@ metrics_fn, model_fn, models, + optimizer_fn, train_transform, ) from renate.benchmark.scenarios import ( @@ -281,6 +282,7 @@ def test_data_module_fn( ("CLEAR10", True, True), ("DomainNet", True, True), ("hfd-rotten_tomatoes", False, False), + ("fmow", True, True), ), ) def test_transforms(dataset_name, use_transforms, test_compose): @@ -363,3 +365,17 @@ def test_loss_fn_returns_correct_reduction_type(): def test_metrics_fn_contains_accuracy(): assert isinstance(metrics_fn(num_outputs=2)["accuracy"], MulticlassAccuracy) assert isinstance(metrics_fn(num_outputs=10)["accuracy"], MulticlassAccuracy) + + +def test_optimizer_fn(): + expected_learning_rate = 0.12 + partial_optimizer = optimizer_fn( + optimizer="AdamW", learning_rate=expected_learning_rate, weight_decay=0.1 + ) + optimizer: AdamW = partial_optimizer(Linear(10, 10).parameters()) + assert isinstance(optimizer, AdamW) + assert optimizer.defaults["lr"] == expected_learning_rate + + +def test_optimizer_fn_unknown_optimizer(): + assert optimizer_fn(optimizer="UNKNOWN_OPTIMIZER", learning_rate=0.1, weight_decay=1) is None From b27fb01ff2041b6036a1589505f4483e6ff6e0a4 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 10 Aug 2023 15:29:39 +0200 Subject: [PATCH 69/89] Clean Up Learner Checkpoint and Fix Model Loading (#365) Signed-off-by: wistuba --- src/renate/updaters/model_updater.py | 28 +++++++++++-------- .../configs/suites/quick/gdumb.json | 4 +-- .../configs/suites/quick/joint.json | 4 +-- test/renate/updaters/test_model_updater.py | 28 +++++++++++++++++-- 4 files changed, 46 insertions(+), 18 deletions(-) diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index c340abde..6d35677e 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -147,11 +147,10 @@ def _load_best_checkpoint_and_save(self, trainer: Trainer, pl_module: LightningM # Save the buffer only on rank zero. pl_module.save(self._output_state_folder) # Overwrite checkpoint. - self._save_checkpoint(trainer, learner_state_path) + self._save_checkpoint(trainer, str(learner_state_path)) def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - """ - teardown implements the separation of learner and model at the end of training. + """Implements the separation of learner and model at the end of training. There are two cases two handle. @@ -161,20 +160,20 @@ def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> deepspeed stage is used. There are three steps here a. combine all the shards into one big state dict. - b. The learner_state_path is a dir (learner.cpkt/). This needs to be deleted first. - c. Write the combined state_dict as the learner.cpkt file as a single file. - d. 
Extract the state_dict element from the learner and save that as the model.cpkt. + b. The learner_state_path is a dir (learner.ckpt/). This needs to be deleted first. + c. Write the combined state_dict as the learner.ckpt file as a single file. + d. Extract the state_dict element from the learner and save that as the model.ckpt. 2. If not deepspeed (say DDP or single device): The steps are much simpler. - a. Load the learner.cpkt and extract the state_dict element. + a. Load the learner.ckpt and extract the state_dict element. b. | Sanitize the extracted state_dict. Learner has the model in a _model attribute. | So strip the first "_model." from the keys of the state_dict. - c. Save the sanitized model to model.cpkt. + c. Save the sanitized model to model.ckpt. Case 2 is needs to be done even for Case 1 (step d). So teardown is a recursive call in - Case 1 which automatically goes to Case 2 as learner.cpkt is file now. + Case 1 which automatically goes to Case 2 as learner.ckpt is file now. """ if trainer.is_global_zero and (stage == "fit"): @@ -187,10 +186,14 @@ def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> self.teardown(trainer, pl_module, stage) elif learner_state_path.exists() and learner_state_path.is_file(): # This a normal file. We strip the model of any wrappers and save that. - state_dict = torch.load(learner_state_path)["state_dict"] - out_sd = {k.replace("_model.", "", 1): v for k, v in state_dict.items()} - # Replace only 1 instance because we have to load it into RenateModule. + learner_state = torch.load(learner_state_path) + out_sd = { + k.replace("_model.", "", 1): v for k, v in learner_state["state_dict"].items() + } # Replace only 1 instance because we have to load it into RenateModule. torch.save(out_sd, defaults.model_file(self.dirpath)) + # Remove model from learner checkpoint + learner_state["state_dict"] = {} + torch.save(learner_state, learner_state_path) def on_exception( self, trainer: Trainer, pl_module: LightningModule, exception: BaseException @@ -382,6 +385,7 @@ def _load_learner( self._learner_state_file, model=self._model, logged_metrics=self._logged_metrics, + strict=False, **self._transforms_kwargs, **learner_kwargs, ) diff --git a/test/integration_tests/configs/suites/quick/gdumb.json b/test/integration_tests/configs/suites/quick/gdumb.json index abb8e597..337bd0de 100644 --- a/test/integration_tests/configs/suites/quick/gdumb.json +++ b/test/integration_tests/configs/suites/quick/gdumb.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "class-incremental-mlp-gdumb", - "expected_accuracy_linux": [[0.43050000071525574, 0.8069999814033508]], - "expected_accuracy_darwin": [[0.43050000071525574, 0.8069999814033508]] + "expected_accuracy_linux": [[0.7384999990463257, 0.9304999709129333]], + "expected_accuracy_darwin": [[0.7384999990463257, 0.9304999709129333]] } diff --git a/test/integration_tests/configs/suites/quick/joint.json b/test/integration_tests/configs/suites/quick/joint.json index f6decc97..b9cea200 100644 --- a/test/integration_tests/configs/suites/quick/joint.json +++ b/test/integration_tests/configs/suites/quick/joint.json @@ -5,6 +5,6 @@ "dataset": "fashionmnist.json", "backend": "local", "job_name": "iid-mlp-joint", - "expected_accuracy_linux": [[0.8496000170707703, 0.8496000170707703], [0.8550000190734863, 0.8550000190734863]], - "expected_accuracy_darwin": [[0.8585000038146973, 0.8585000038146973]] + "expected_accuracy_linux": [[0.8425999879837036, 0.8425999879837036], 
[0.8446999788284302, 0.8446999788284302]], + "expected_accuracy_darwin": [[0.8432000279426575, 0.8432000279426575]] } diff --git a/test/renate/updaters/test_model_updater.py b/test/renate/updaters/test_model_updater.py index 8d76428d..49e9d713 100644 --- a/test/renate/updaters/test_model_updater.py +++ b/test/renate/updaters/test_model_updater.py @@ -34,9 +34,33 @@ def test_simple_model_updater(tmpdir, provide_folder): assert not torch.allclose(y_hat_before_train, y_hat_after_train) +def test_model_passed_is_used_as_is(tmpdir): + """Makes sure that the model passed to the updater is not overwritten by anything in the + checkpoint""" + model, train_dataset, _ = pytest.helpers.get_renate_module_mlp_and_data( + num_inputs=10, + num_outputs=10, + hidden_size=32, + num_hidden_layers=3, + train_num_samples=10, + test_num_samples=5, + ) + model2 = deepcopy(model) + model_updater = pytest.helpers.get_simple_updater(model, output_state_folder=tmpdir) + model_updater.update(train_dataset, task_id=defaults.TASK_ID) + + expected_model = deepcopy(model2) + model_updater = pytest.helpers.get_simple_updater(model2, input_state_folder=tmpdir) + for p1, p2 in zip( + expected_model.parameters(), + model_updater._learner._model.parameters(), + ): + assert torch.allclose(p1, p2) + + def test_deterministic_updater(): - # The behavior is always deterministic on CPU but it can become non-deterministic on GPU - # When run on CPU this test never fails so it is only useful when tests are run on GPU + # The behavior is always deterministic on CPU, but it can become non-deterministic on GPU + # When run on CPU this test never fails, so it is only useful when tests are run on GPU model1, train_dataset, test_data = pytest.helpers.get_renate_module_mlp_and_data( num_inputs=10, num_outputs=10, From 6f8acf6c48e652e80b29449ad5dadfdfd2c4ff21 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 14 Aug 2023 17:38:18 +0200 Subject: [PATCH 70/89] Enable Custom Grouping for DataIncrementalScenario (#368) --- benchmarks/run_benchmark.py | 1 + doc/benchmarking/renate_benchmarks.rst | 9 +- .../class_incremental_learning_cifar10_der.py | 2 +- .../renate_config.py | 2 +- examples/train_mlp_locally/renate_config.py | 2 +- .../benchmark/datasets/wild_time_data.py | 2 +- src/renate/benchmark/experiment_config.py | 81 ++++++++++++----- src/renate/benchmark/scenarios.py | 71 ++++++++++----- src/renate/cli/parsing_functions.py | 2 +- test/dummy_datasets.py | 38 +++++++- .../scenarios/class-incremental-2updates.json | 2 +- .../scenarios/class-incremental-5updates.json | 2 +- .../benchmark/test_experimentation_config.py | 59 ++++++------ test/renate/benchmark/test_scenarios.py | 89 ++++++++++++++++--- test/renate/cli/test_parsing_functions.py | 8 +- test/renate/renate_config_files/config.py | 2 +- .../renate_config_files/config_scenario.py | 4 +- 17 files changed, 275 insertions(+), 101 deletions(-) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 953e0819..8c9dd08f 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -112,6 +112,7 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): instance_type="ml.g4dn.xlarge", n_workers=1, max_time=args.max_time, + instance_max_time=args.max_time, seed=seed, job_name=args.job_name[:36], devices=1, diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index b97501a7..b65c1a2f 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -187,7 +187,7 @@ The 
first part contains all instances with classes 1 and 2, the second with clas .. code-block:: python config_space["scenario_name"] = "ClassIncrementalScenario" - config_space["class_groupings"] = ((1, 2), (3, 4)) + config_space["groupings"] = ((1, 2), (3, 4)) .. list-table:: Renate Scenario Overview :widths: 15 35 35 @@ -202,10 +202,13 @@ The first part contains all instances with classes 1 and 2, the second with clas Data is presented data by data, where the data could represent a domain or a time slice. - * :code:`num_tasks`: You can provide this argument if the different datasets are identified by ids 0 to `num_tasks`. This is the case for time-incremental datasets such as CLEAR or Wild-Time. - * :code:`data_ids`: List of data identifiers. Used for DomainNet to select order or subset of domains. + * :code:`data_ids`: Tuple of data identifiers. Used for DomainNet to select order or subset of domains, + e.g., ``("clipart", "infograph", "painting")``. + * :code:`groupings`: An alternative to data identifiers that in addition to defining the sequence + allows to combine different domains to one chunk, e.g., ``(("clipart", ), ("infograph", "painting"))``. * - :py:class:`~renate.benchmark.scenarios.ClassIncrementalScenario` - Creates data partitions by splitting the data according to class labels. - - * :code:`class_groupings`: Tuple of tuples containing the class labels, e.g., ``((1, ), (2, 3, 4))``. + - * :code:`groupings`: Tuple of tuples containing the class labels, e.g., ``((1, ), (2, 3, 4))``. * - :py:class:`~renate.benchmark.scenarios.FeatureSortingScenario` - Splits data into different tasks after sorting the data according to a specific feature. Can be used for image data as well. In that case channels are selected and we select according to diff --git a/examples/benchmarking/class_incremental_learning_cifar10_der.py b/examples/benchmarking/class_incremental_learning_cifar10_der.py index 74c182a7..b9b8a67e 100644 --- a/examples/benchmarking/class_incremental_learning_cifar10_der.py +++ b/examples/benchmarking/class_incremental_learning_cifar10_der.py @@ -21,7 +21,7 @@ "scenario_name": "ClassIncrementalScenario", "dataset_name": "CIFAR10", "val_size": 0, - "class_groupings": ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9)), + "groupings": ((0, 1), (2, 3), (4, 5), (6, 7), (8, 9)), "num_outputs": 10, } diff --git a/examples/simple_classifier_cifar10/renate_config.py b/examples/simple_classifier_cifar10/renate_config.py index 1594dcac..e7b276a6 100644 --- a/examples/simple_classifier_cifar10/renate_config.py +++ b/examples/simple_classifier_cifar10/renate_config.py @@ -37,7 +37,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> ) class_incremental_scenario = ClassIncrementalScenario( data_module=data_module, - class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), + groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), chunk_id=chunk_id, ) return class_incremental_scenario diff --git a/examples/train_mlp_locally/renate_config.py b/examples/train_mlp_locally/renate_config.py index decf622d..88d6114a 100644 --- a/examples/train_mlp_locally/renate_config.py +++ b/examples/train_mlp_locally/renate_config.py @@ -28,7 +28,7 @@ def data_module_fn(data_path: str, chunk_id: int, seed: int = defaults.SEED) -> class_incremental_scenario = ClassIncrementalScenario( data_module=data_module, - class_groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), + groupings=((0, 1, 2, 3, 4), (5, 6, 7, 8, 9)), chunk_id=chunk_id, ) return class_incremental_scenario diff --git 
a/src/renate/benchmark/datasets/wild_time_data.py b/src/renate/benchmark/datasets/wild_time_data.py index b8260222..0d83730b 100644 --- a/src/renate/benchmark/datasets/wild_time_data.py +++ b/src/renate/benchmark/datasets/wild_time_data.py @@ -90,7 +90,7 @@ def setup(self) -> None: "time_step": available_time_steps(self._dataset_name)[self.data_id], "data_dir": self._data_path, "in_memory": self._dataset_name != "fmow", - "transform": None if self._dataset_name != "fmow" else lambda x: x, + "transform": None if self._dataset_name not in ["fmow", "yearbook"] else lambda x: x, } if self._tokenizer: kwargs["transform"] = lambda x: self._tokenizer(x, **(self._tokenizer_kwargs or {})) diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 7516579c..665d5a5b 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -11,7 +11,6 @@ from torchvision.transforms import transforms from transformers import AutoTokenizer from wild_time_data import default_transform -from wild_time_data.datasets import FMoW from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule from renate.benchmark.datasets.vision_datasets import ( @@ -165,12 +164,12 @@ def get_scenario( chunk_id: int, seed: int, num_tasks: Optional[int] = None, - class_groupings: Optional[Tuple[Tuple[int]]] = None, + groupings: Optional[Tuple[Tuple[int]]] = None, degrees: Optional[List[int]] = None, input_dim: Optional[Union[List[int], Tuple[int], int]] = None, feature_idx: Optional[int] = None, randomness: Optional[float] = None, - data_ids: Optional[List[Union[int, str]]] = None, + data_ids: Optional[Tuple[Union[int, str]]] = None, ) -> Scenario: """Function to create scenario based on name and arguments. @@ -180,8 +179,8 @@ def get_scenario( chunk_id: The data chunk to load in for the training or validation data. seed: A random seed to fix the created scenario. num_tasks: The total number of expected tasks for experimentation. - class_groupings: Used for scenario `ClassIncrementalScenario`. Partitions classes into - different chunks. + groupings: Used for scenario `ClassIncrementalScenario` to partition datasets into chunks by + class. Used by `DataIncrementalScenario` to group domains to chunks.. degrees: Used for scenario `ImageRotationScenario`. Rotations applied for each chunk. input_dim: Used for scenario `PermutationScenario`. Input dimensionality. feature_idx: Used for scenario `SoftSortingScenario`. Index of feature to sort by. @@ -195,12 +194,10 @@ def get_scenario( ValueError: If scenario name is unknown. """ if scenario_name == "ClassIncrementalScenario": - assert ( - class_groupings is not None - ), "Provide `class_groupings` for the class-incremental scenario." + assert groupings is not None, "Provide `groupings` for the class-incremental scenario." 
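        # Each inner tuple of `groupings` lists the class labels that form one data chunk,
        # e.g., groupings=((0, 1), (2, 3)) defines two chunks, and `chunk_id` selects one.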
return ClassIncrementalScenario( data_module=data_module, - class_groupings=class_groupings, + groupings=groupings, chunk_id=chunk_id, ) if scenario_name == "IIDScenario": @@ -237,10 +234,14 @@ def get_scenario( seed=seed, ) if scenario_name == "DataIncrementalScenario": - if data_ids is None: + if data_ids is None and groupings is None: data_ids = [data_id for data_id in range(num_tasks)] return DataIncrementalScenario( - data_module=data_module, data_ids=data_ids, chunk_id=chunk_id, seed=seed + data_module=data_module, + chunk_id=chunk_id, + data_ids=data_ids, + groupings=groupings, + seed=seed, ) raise ValueError(f"Unknown scenario `{scenario_name}`.") @@ -259,7 +260,7 @@ def data_module_fn( dataset_name: str, val_size: float = 0.0, num_tasks: Optional[int] = None, - class_groupings: Optional[Tuple[Tuple[int]]] = None, + groupings: Optional[Tuple[Tuple[int]]] = None, degrees: Optional[Tuple[int]] = None, input_dim: Optional[Tuple[int]] = None, feature_idx: Optional[int] = None, @@ -290,7 +291,7 @@ def data_module_fn( chunk_id=chunk_id, seed=seed, num_tasks=num_tasks, - class_groupings=class_groupings, + groupings=groupings, degrees=degrees, input_dim=input_dim, feature_idx=feature_idx, @@ -317,10 +318,27 @@ def _get_normalize_transform(dataset_name): ) -def train_transform(dataset_name: str) -> Optional[Callable]: +def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Optional[Callable]: """Returns a transform function to be used in the training.""" if dataset_name == "fmow": return default_transform(dataset_name) + if dataset_name == "yearbook": + if ( + model_name is not None + and model_name.startswith("VisionTransformer") + and model_name != "VisionTransformerCIFAR" + ): + return transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize(224), + transforms.RandomHorizontalFlip(), + default_transform(dataset_name), + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), + ] + ) + else: + return default_transform(dataset_name) if dataset_name in [ "MNIST", "FashionMNIST", @@ -337,21 +355,38 @@ def train_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ - transforms.ToTensor(), - transforms.Resize( - 224, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True - ), + transforms.Resize(224), transforms.RandomCrop(224), + transforms.ToTensor(), _get_normalize_transform(dataset_name), ] ) raise ValueError(f"Unknown dataset `{dataset_name}`.") -def test_transform(dataset_name: str) -> Optional[Callable]: +def test_transform( + dataset_name: str, + model_name: Optional[str] = None, +) -> Optional[Callable]: """Returns a transform function to be used for validation or testing.""" if dataset_name == "fmow": - return FMoW.default_transform + return default_transform(dataset_name) + if dataset_name == "yearbook": + if ( + model_name is not None + and model_name.startswith("VisionTransformer") + and model_name != "VisionTransformerCIFAR" + ): + return transforms.Compose( + [ + transforms.ToPILImage(), + transforms.Resize(224), + default_transform(dataset_name), + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), + ] + ) + else: + return default_transform(dataset_name) if dataset_name in [ "MNIST", "FashionMNIST", @@ -362,11 +397,9 @@ def test_transform(dataset_name: str) -> Optional[Callable]: if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ - transforms.ToTensor(), - transforms.Resize( - 224, 
interpolation=transforms.InterpolationMode.BILINEAR, antialias=True - ), + transforms.Resize(224), transforms.CenterCrop(224), + transforms.ToTensor(), _get_normalize_transform(dataset_name), ] ) diff --git a/src/renate/benchmark/scenarios.py b/src/renate/benchmark/scenarios.py index 70725423..f67268b9 100644 --- a/src/renate/benchmark/scenarios.py +++ b/src/renate/benchmark/scenarios.py @@ -5,7 +5,7 @@ import numpy as np import torch -from torch.utils.data import Dataset, Subset +from torch.utils.data import ConcatDataset, Dataset, Subset from torchvision.transforms import Lambda, RandomRotation, ToPILImage from renate import defaults @@ -116,17 +116,17 @@ class ClassIncrementalScenario(Scenario): Args: data_module: The source RenateDataModule for the user data. chunk_id: The data chunk to load in for the training or validation data. - class_groupings: List of lists, describing the division of the classes for respective tasks. + groupings: Tuple of tuples, describing the division of the classes for respective tasks. """ def __init__( self, data_module: RenateDataModule, chunk_id: int, - class_groupings: Tuple[Tuple[int, ...], ...], + groupings: Tuple[Tuple[int, ...], ...], ) -> None: - super().__init__(data_module, len(class_groupings), chunk_id) - self._class_groupings = class_groupings + super().__init__(data_module, len(groupings), chunk_id) + self._class_groupings = groupings def setup(self) -> None: """Make assignments: val/train/test splits.""" @@ -397,42 +397,73 @@ class DataIncrementalScenario(Scenario): Args: data_module: The source :py:class:`~renate.data.data_module.RenateDataModule` for the user data. - data_ids: Unique identifier for each pre-defined dataset. chunk_id: The data chunk to load in for the training or validation data. + data_ids: Unique identifier for each pre-defined dataset. + groupings: Tuple of tuples that group different datasets associated to a ``data_id`` to + one dataset. seed: Seed used to fix random number generation. """ def __init__( self, data_module: RenateDataModule, - data_ids: List[Union[int, str]], chunk_id: int, + data_ids: Optional[Tuple[Union[int, str], ...]] = None, + groupings: Optional[Tuple[Tuple[int, ...], ...]] = None, seed: int = defaults.SEED, ) -> None: - super().__init__( - data_module=data_module, num_tasks=len(data_ids), chunk_id=chunk_id, seed=seed - ) if not isinstance(data_module, DataIncrementalDataModule): raise ValueError( "This scenario is only compatible with classes that extend " "`DataIncrementalDataModule`." ) - self._data_ids = data_ids + if data_ids is None and groupings is None: + raise ValueError( + "Either `data_ids` or `groupings` must be provided. None was provided." + ) + if data_ids is not None and groupings is not None: + raise ValueError( + "Either `data_ids` or `groupings` must be provided. Both were provided." 
+ ) + if data_ids is not None: + self._groupings = tuple((data_id,) for data_id in data_ids) + else: + self._groupings = groupings + super().__init__( + data_module=data_module, num_tasks=len(self._groupings), chunk_id=chunk_id, seed=seed + ) def prepare_data(self) -> None: """Downloads datasets.""" - for data_id in self._data_ids: + for data_id in set(data_id for grouping in self._groupings for data_id in grouping): self._data_module.data_id = data_id self._data_module.prepare_data() + def _create_dataset_from_grouping(self, grouping) -> Tuple[Dataset, Optional[Dataset], Dataset]: + train_data = [] + val_data = [] + test_data = [] + for data_id in grouping: + self._data_module.data_id = data_id + self._data_module.setup() + train_data.append(self._data_module.train_data()) + if self._data_module.val_data() is not None: + val_data.append(self._data_module.val_data()) + test_data.append(self._data_module.test_data()) + if len(grouping) == 1: + return train_data[0], val_data[0] if len(val_data) else None, test_data[0] + return ( + ConcatDataset(train_data), + ConcatDataset(val_data) if len(val_data) else None, + ConcatDataset(test_data), + ) + def setup(self) -> None: """Sets up the scenario.""" - self._data_module.data_id = self._data_ids[self._chunk_id] super().setup() - self._train_data = self._data_module.train_data() - self._val_data = self._data_module.val_data() - self._test_data = [] - for data_id in self._data_ids: - self._data_module.data_id = data_id - self._data_module.setup() - self._test_data.append(self._data_module.test_data()) + self._train_data, self._val_data, _ = self._create_dataset_from_grouping( + self._groupings[self._chunk_id] + ) + self._test_data = [ + self._create_dataset_from_grouping(grouping)[2] for grouping in self._groupings + ] diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index cb79af5c..fff36f49 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -862,7 +862,7 @@ def raise_error(arg_type, arg_name): return argument_type -def to_dense_str(value: Union[bool, List, Tuple]) -> str: +def to_dense_str(value: Union[bool, List, Tuple, None]) -> str: """Converts a variable to string without empty spaces.""" return str(value).replace(" ", "") diff --git a/test/dummy_datasets.py b/test/dummy_datasets.py index 091d57d2..ca5d24f9 100644 --- a/test/dummy_datasets.py +++ b/test/dummy_datasets.py @@ -1,11 +1,12 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable, Optional +from typing import Callable, Optional, Tuple import torch from torch.utils.data import Dataset from renate import defaults +from renate.benchmark.datasets.base import DataIncrementalDataModule from renate.data.data_module import RenateDataModule from renate.utils.pytorch import get_generator @@ -69,3 +70,38 @@ def setup(self): self._train_data, self._val_data = self._split_train_val_data(train_data) self.X_test, self.y_test = self._get_random_data() self._test_data = DummyDataset(self.X_test, self.y_test, self._transform) + + +class DummyDataIncrementalDataModule(DataIncrementalDataModule): + """Simple dataset to test whether DataIncrementalDataModule or DataIncrementalScenario work. + + This data module gives access to 5 different datasets with ids [0, 1, 2, 3, 4]. + The labels of dataset with data_id=i are all i. 
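+
+    A minimal usage sketch:
+
+        data_module = DummyDataIncrementalDataModule(data_id=2, input_shape=(10, 1))
+        data_module.prepare_data()
+        data_module.setup()  # train/test targets are all equal to 2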
+ """ + + def __init__( + self, + data_id: int, + input_shape: Tuple[int, ...], + transform: Optional[Callable] = None, + val_size: float = defaults.VALIDATION_SIZE, + seed: int = defaults.SEED, + ): + super().__init__(data_path="", data_id=data_id, val_size=val_size, seed=seed) + rng = get_generator(self._seed) + + def get_data(i): + return DummyDataset( + data=torch.rand((100, *input_shape), generator=rng, dtype=torch.float32), + targets=torch.zeros(100) + i, + transform=transform, + ) + + self._datasets = [[get_data(i) for _ in range(2)] for i in range(5)] + + def prepare_data(self) -> None: + pass + + def setup(self): + self._train_data, self._test_data = self._datasets[self.data_id] + self._train_data, self._val_data = self._split_train_val_data(self._train_data) diff --git a/test/integration_tests/configs/scenarios/class-incremental-2updates.json b/test/integration_tests/configs/scenarios/class-incremental-2updates.json index e4f58c46..4ad4e551 100644 --- a/test/integration_tests/configs/scenarios/class-incremental-2updates.json +++ b/test/integration_tests/configs/scenarios/class-incremental-2updates.json @@ -1,7 +1,7 @@ { "val_size": 0.05, "scenario_name": "ClassIncrementalScenario", - "class_groupings": [[0, 1], [2, 3]], + "groupings": [[0, 1], [2, 3]], "num_tasks": 2, "max_epochs": 5 } diff --git a/test/integration_tests/configs/scenarios/class-incremental-5updates.json b/test/integration_tests/configs/scenarios/class-incremental-5updates.json index cf1879ab..b297ac18 100644 --- a/test/integration_tests/configs/scenarios/class-incremental-5updates.json +++ b/test/integration_tests/configs/scenarios/class-incremental-5updates.json @@ -1,7 +1,7 @@ { "val_size": 0.05, "scenario_name": "ClassIncrementalScenario", - "class_groupings": [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], + "groupings": [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]], "num_tasks": 5, "max_epochs": 50 } diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 512c40d0..af5cf9d1 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -5,7 +5,7 @@ from torch.optim import AdamW, SGD from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR from torchmetrics.classification import MulticlassAccuracy -from torchvision.transforms import Compose, Normalize +from torchvision.transforms import Compose, Normalize, ToTensor from renate.benchmark import experiment_config from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule @@ -157,7 +157,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "pretrained_model_name": "distilbert-base-uncased", "input_column": "text", "target_column": "coarse_label", - "class_groupings": ((0, 1), (2, 3), (4, 5)), + "groupings": ((0, 1), (2, 3), (4, 5)), }, ClassIncrementalScenario, 3, @@ -219,7 +219,14 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): ( "DataIncrementalScenario", "DomainNet", - {"data_ids": ["clipart", "infograph"]}, + {"data_ids": ("clipart", "infograph")}, + DataIncrementalScenario, + 2, + ), + ( + "DataIncrementalScenario", + "DomainNet", + {"groupings": (("clipart", "infograph"), "painting")}, DataIncrementalScenario, 2, ), @@ -231,10 +238,11 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "permutation", "feature_sorting", "hue_shift", - "time_with_clear", + "data_incremental with CLEAR", "wild_time_text_with_tokenizer", "wild_time_image_all_tasks", - "domainnet", + "domainnet_by 
data_id", + "domainnet by groupings", ], ) @pytest.mark.parametrize("val_size", (0, 0.5), ids=["no_val", "val"]) @@ -258,7 +266,7 @@ def test_data_module_fn( ) assert isinstance(scenario, expected_scenario_class) if expected_scenario_class == ClassIncrementalScenario: - assert scenario._class_groupings == scenario_kwargs["class_groupings"] + assert scenario._class_groupings == scenario_kwargs["groupings"] elif expected_scenario_class == FeatureSortingScenario: assert scenario._feature_idx == scenario_kwargs["feature_idx"] assert scenario._randomness == scenario_kwargs["randomness"] @@ -273,30 +281,27 @@ def test_data_module_fn( @pytest.mark.parametrize( - "dataset_name,use_transforms,test_compose", + "dataset_name,expected_train_transform_class, expected_test_transform_class,model_name", ( - ("MNIST", False, False), - ("FashionMNIST", False, False), - ("CIFAR10", True, False), - ("CIFAR100", True, False), - ("CLEAR10", True, True), - ("DomainNet", True, True), - ("hfd-rotten_tomatoes", False, False), - ("fmow", True, True), + ("MNIST", type(None), type(None), "ResNet18CIFAR"), + ("FashionMNIST", type(None), type(None), "ResNet18CIFAR"), + ("CIFAR10", Compose, Normalize, "ResNet18CIFAR"), + ("CIFAR100", Compose, Normalize, "ResNet18CIFAR"), + ("CLEAR10", Compose, Compose, "ResNet18"), + ("DomainNet", Compose, Compose, "VisionTransformerB16"), + ("hfd-rotten_tomatoes", type(None), type(None), "HuggingFaceTransformer"), + ("fmow", Compose, Compose, "ResNet18"), + ("yearbook", ToTensor, ToTensor, "ResNet18CIFAR"), + ("yearbook", Compose, Compose, "VisionTransformerB16"), ), ) -def test_transforms(dataset_name, use_transforms, test_compose): - train_preprocessing = train_transform(dataset_name) - test_preprocessing = experiment_config.test_transform(dataset_name) - if use_transforms: - assert isinstance(train_preprocessing, Compose) - if test_compose: - assert isinstance(test_preprocessing, Compose) - else: - assert isinstance(test_preprocessing, Normalize) - else: - assert train_preprocessing is None - assert test_preprocessing is None +def test_transforms( + dataset_name, expected_train_transform_class, expected_test_transform_class, model_name +): + train_preprocessing = train_transform(dataset_name, model_name) + test_preprocessing = experiment_config.test_transform(dataset_name, model_name) + assert isinstance(train_preprocessing, expected_train_transform_class) + assert isinstance(test_preprocessing, expected_test_transform_class) def test_transforms_fails_for_unknown_dataset(): diff --git a/test/renate/benchmark/test_scenarios.py b/test/renate/benchmark/test_scenarios.py index 60950282..5449d43f 100644 --- a/test/renate/benchmark/test_scenarios.py +++ b/test/renate/benchmark/test_scenarios.py @@ -7,7 +7,7 @@ import torch from torchvision.transforms.functional import rotate -from dummy_datasets import DummyTorchVisionDataModule +from dummy_datasets import DummyDataIncrementalDataModule, DummyTorchVisionDataModule from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule from renate.benchmark.scenarios import ( ClassIncrementalScenario, @@ -32,7 +32,7 @@ { "num_tasks": 3, "chunk_id": 4, # Wrong chunk id - "class_groupings": [[0, 1, 2], [3, 5, 9], [4, 6, 7, 8]], + "groupings": ((0, 1, 2), (3, 5, 9), (4, 6, 7, 8)), }, ], [PermutationScenario, {"num_tasks": 3, "chunk_id": 2}], # Missing input dim @@ -49,29 +49,29 @@ def test_failing_to_init(tmpdir, scenario_cls, kwargs): def test_class_incremental_scenario(): data_module = DummyTorchVisionDataModule(val_size=0.3, seed=42) - 
class_groupings = [[0, 1, 3], [2], [3, 4]] + groupings = ((0, 1, 3), (2,), (3, 4)) train_data_class_counts = Counter({3: 16, 4: 15, 0: 15, 2: 13, 1: 11}) val_data_class_counts = Counter({1: 9, 2: 7, 4: 5, 0: 5, 3: 4}) test_data_class_counts = Counter({0: 20, 1: 20, 2: 20, 3: 20, 4: 20}) - for i in range(len(class_groupings)): + for i in range(len(groupings)): scenario = ClassIncrementalScenario( - data_module=data_module, class_groupings=class_groupings, chunk_id=i + data_module=data_module, groupings=groupings, chunk_id=i ) scenario.prepare_data() scenario.setup() train_data = scenario.train_data() val_data = scenario.val_data() - assert len(train_data) == sum([train_data_class_counts[c] for c in class_groupings[i]]) - assert len(val_data) == sum([val_data_class_counts[c] for c in class_groupings[i]]) + assert len(train_data) == sum([train_data_class_counts[c] for c in groupings[i]]) + assert len(val_data) == sum([val_data_class_counts[c] for c in groupings[i]]) for j, test_data in enumerate(scenario.test_data()): - assert len(test_data) == sum([test_data_class_counts[c] for c in class_groupings[j]]) + assert len(test_data) == sum([test_data_class_counts[c] for c in groupings[j]]) -def test_class_incremental_scenario_class_grouping_error(): +def test_class_incremental_scenario_groupings_error(): """Classes selected do not exist in data.""" scenario = ClassIncrementalScenario( data_module=DummyTorchVisionDataModule(val_size=0.3, seed=42), - class_groupings=((0, 1, 3), (2, 200)), + groupings=((0, 1, 3), (2, 200)), chunk_id=0, ) scenario.prepare_data() @@ -79,12 +79,77 @@ def test_class_incremental_scenario_class_grouping_error(): scenario.setup() -def test_data_incremental_scenario_init_error(): +def test_data_incremental_scenario_grouping(): + """Test whether grouping in DataIncrementalScenario works""" + scenario = DataIncrementalScenario( + data_module=DummyDataIncrementalDataModule(0, (10, 1), val_size=0.2), + chunk_id=0, + groupings=((0, 1), (2,)), + ) + scenario.prepare_data() + scenario.setup() + assert set(int(scenario.train_data()[i][1]) for i in range(80)) == {0} + assert set(int(scenario.train_data()[i][1]) for i in range(80, 160)) == {1} + assert set(int(scenario.val_data()[i][1]) for i in range(20)) == {0} + assert set(int(scenario.val_data()[i][1]) for i in range(20, 40)) == {1} + assert set(int(scenario.test_data()[0][i][1]) for i in range(100)) == {0} + assert set(int(scenario.test_data()[0][i][1]) for i in range(100, 200)) == {1} + assert set(int(scenario.test_data()[1][i][1]) for i in range(100)) == {2} + assert len(scenario.train_data()) == 160 + assert len(scenario.val_data()) == 40 + assert len(scenario.test_data()) == 2 + assert len(scenario.test_data()[0]) == 200 + assert len(scenario.test_data()[1]) == 100 + + +def test_data_incremental_scenario_data_ids(): + """Test whether data_ids in DataIncrementalScenario works""" + scenario = DataIncrementalScenario( + data_module=DummyDataIncrementalDataModule(0, (10, 1), val_size=0.2), + chunk_id=1, + data_ids=(3, 4), + ) + scenario.prepare_data() + scenario.setup() + assert set(int(scenario.train_data()[i][1]) for i in range(80)) == {4} + assert set(int(scenario.val_data()[i][1]) for i in range(20)) == {4} + assert set(int(scenario.test_data()[0][i][1]) for i in range(100)) == {3} + assert set(int(scenario.test_data()[1][i][1]) for i in range(100)) == {4} + assert len(scenario.train_data()) == 80 + assert len(scenario.val_data()) == 20 + assert len(scenario.test_data()) == 2 + assert len(scenario.test_data()[0]) == 100 + 
assert len(scenario.test_data()[1]) == 100 + + +def test_data_incremental_scenario_data_module_error(): """Check that DataIncrementalScenario raises Exception for unsupported DataModule.""" with pytest.raises(ValueError, match=r"This scenario is only compatible with*"): DataIncrementalScenario( data_module=DummyTorchVisionDataModule(), - data_ids=[0, 1], + data_ids=(0, 1), + chunk_id=0, + ) + + +@pytest.mark.parametrize( + "data_ids,groupings,expected_error_message", + ( + ( + (0, 1), + ((0,), (1,)), + "Either `data_ids` or `groupings` must be provided. Both were provided.", + ), + (None, None, "Either `data_ids` or `groupings` must be provided. None was provided."), + ), +) +def test_data_incremental_scenario_wrong_input(data_ids, groupings, expected_error_message): + """Scenario expect either data_ids or groupings.""" + with pytest.raises(ValueError, match=expected_error_message): + DataIncrementalScenario( + data_module=DummyDataIncrementalDataModule(data_id=0, input_shape=(2, 2)), + data_ids=data_ids, + groupings=groupings, chunk_id=0, ) diff --git a/test/renate/cli/test_parsing_functions.py b/test/renate/cli/test_parsing_functions.py index 24db77ea..6b3da15e 100644 --- a/test/renate/cli/test_parsing_functions.py +++ b/test/renate/cli/test_parsing_functions.py @@ -90,7 +90,7 @@ def test_get_function_args(all_args, ignore_args): "data_path", "val_size", "seed", - "class_groupings", + "groupings", "optional_tuple", "optional_float", "list_param", @@ -117,7 +117,7 @@ def test_get_function_args(all_args, ignore_args): "default": 0, "true_type": int, }, - "class_groupings": { + "groupings": { "type": str, "argument_group": CUSTOM_ARGS_GROUP, "default": "((0,1),(2,3,4))", @@ -207,7 +207,7 @@ def test_get_fn_kwargs_helper_functions(): and the Python function.""" expected_data_module_kwargs = { "data_path": "home/data/path", - "class_groupings": ((1, 2), (3, 4)), + "groupings": ((1, 2), (3, 4)), "optional_float": None, "bool_param": False, } @@ -215,7 +215,7 @@ def test_get_fn_kwargs_helper_functions(): "data_path": expected_data_module_kwargs["data_path"], "model_state_url": "home/model/state", "unused_config": 1, - "class_groupings": to_dense_str(expected_data_module_kwargs["class_groupings"]), + "groupings": to_dense_str(expected_data_module_kwargs["groupings"]), "optional_float": to_dense_str(expected_data_module_kwargs["optional_float"]), "bool_param": to_dense_str(expected_data_module_kwargs["bool_param"]), "num_outputs": 10, diff --git a/test/renate/renate_config_files/config.py b/test/renate/renate_config_files/config.py index 5b7a6900..7152ff32 100644 --- a/test/renate/renate_config_files/config.py +++ b/test/renate/renate_config_files/config.py @@ -22,7 +22,7 @@ def data_module_fn( data_path: str, val_size: float = 0.0, seed: int = 0, - class_groupings: Tuple[Tuple[int]] = ((0, 1), (2, 3, 4)), + groupings: Tuple[Tuple[int]] = ((0, 1), (2, 3, 4)), optional_tuple: Optional[Tuple[float]] = None, optional_float: Optional[float] = None, list_param: list = [1, 2], diff --git a/test/renate/renate_config_files/config_scenario.py b/test/renate/renate_config_files/config_scenario.py index dac20391..df3d41da 100644 --- a/test/renate/renate_config_files/config_scenario.py +++ b/test/renate/renate_config_files/config_scenario.py @@ -23,13 +23,13 @@ def data_module_fn( chunk_id: Optional[int] = None, val_size: float = 0.0, seed: int = 0, - class_groupings: Tuple[Tuple[int, ...], ...] = ((0, 1), (2, 3, 4)), + groupings: Tuple[Tuple[int, ...], ...] 
= ((0, 1), (2, 3, 4)), ) -> Scenario: data_module = DummyTorchVisionDataModule(transform=None, val_size=val_size, seed=seed) return ClassIncrementalScenario( data_module=data_module, chunk_id=chunk_id, - class_groupings=class_groupings, + groupings=groupings, ) From 41bd5b9904874e3d3d7b88a1e806d3f1f2828f78 Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 14 Aug 2023 17:39:10 +0200 Subject: [PATCH 71/89] MultiText dataset Added to Benchmarking (#366) --- benchmarks/run_benchmark.py | 11 ++++ doc/benchmarking/renate_benchmarks.rst | 6 +- src/renate/benchmark/datasets/nlp_datasets.py | 65 +++++-------------- src/renate/benchmark/experiment_config.py | 18 ++++- src/renate/defaults.py | 2 - .../benchmark/test_experimentation_config.py | 3 +- 6 files changed, 49 insertions(+), 56 deletions(-) diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py index 8c9dd08f..edbf454f 100644 --- a/benchmarks/run_benchmark.py +++ b/benchmarks/run_benchmark.py @@ -59,6 +59,12 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): default=5 * 24 * 3600, help="Maximum execution time.", ) + parser.add_argument( + "--use-prior-task-weights", + action="store_true", + help="Do not reset model weights (Joint and GDumb only)", + ) + args = parser.parse_args() current_folder = Path(os.path.dirname(__file__)) configs_folder = current_folder / "experiment_configs" @@ -73,6 +79,11 @@ def load_config(scenario_file, model_file, updater_file, dataset_file): os.path.join(configs_folder, "updaters", benchmark_config["updater"]), os.path.join(configs_folder, "datasets", benchmark_config["dataset"]), ) + if args.use_prior_task_weights: + if config_space["updater"] in ["Joint", "GDumb"]: + config_space["reset"] = False + else: + raise ValueError("Please use `use-prior-task-weights` only for Joint or GDumb.") config_space["max_epochs"] = int(args.budget_factor * config_space["max_epochs"]) if "learning_rate_scheduler_step_size" in config_space: config_space["learning_rate_scheduler_step_size"] = int( diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index b65c1a2f..0daf12ee 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -161,6 +161,10 @@ The following table contains the list of supported datasets. - Image Classification - 60k train, 10k test, 10 classes, image shape 28x28x1 - Li Deng: The MNIST Database of Handwritten Digit Images for Machine Learning Research. IEEE Signal Processing Magazine. 2012. + * - MultiText + - Text Classification + - 115k train, 7.6k test, access to one of four datasets: ag_news, yelp_review_full, dbpedia_14, yahoo_answers_topics + - Please refer to `the official documentation `__. * - yearbook - Image Classification: gender identification in yearbook photos. - ~33k train, ~4k test, 2 classes, years 1930-2013, image shape 32x32x1 @@ -198,7 +202,7 @@ The first part contains all instances with classes 1 and 2, the second with clas - Settings * - :py:class:`~renate.benchmark.scenarios.DataIncrementalScenario` - Used in combination only with :py:class:`~renate.benchmark.datasets.base.DataIncrementalDataModule`, - e.g., Wild-Time datasets, CLEAR, or DomainNet. + e.g., Wild-Time datasets, CLEAR, MultiText, or DomainNet. Data is presented data by data, where the data could represent a domain or a time slice. - * :code:`num_tasks`: You can provide this argument if the different datasets are identified by ids 0 to `num_tasks`. 
This is the case for time-incremental datasets such as CLEAR or Wild-Time. diff --git a/src/renate/benchmark/datasets/nlp_datasets.py b/src/renate/benchmark/datasets/nlp_datasets.py index 5abe1fc8..623eb928 100644 --- a/src/renate/benchmark/datasets/nlp_datasets.py +++ b/src/renate/benchmark/datasets/nlp_datasets.py @@ -142,14 +142,13 @@ def tokenize_fn(batch): class MultiTextDataModule(DataIncrementalDataModule): """ Inspired by the dataset used in "Episodic Memory in Lifelong Language Learning" - by d’Autume et al. this is a collection of five different datasets that we call domains: + by d’Autume et al. this is a collection of four different datasets that we call domains: AGNews, Yelp, DBPedia and Yahoo Answers. The output space if the union of the output space of all the domains. The dataset has 33 classes: 4 from AGNews, 5 from Yelp, 14 from DBPedia, and 10 from Yahoo. - The maximum allowed size for the training set is 115000 and for the test set is 7600. - Each domain will have the same fixed size. + The largest available size for the training set is 115000 and for the test set is 7600. Args: data_path: The path to the folder where the data files will be downloaded to. @@ -173,43 +172,14 @@ class MultiTextDataModule(DataIncrementalDataModule): "dbpedia_14": ["content", "label"], "yahoo_answers_topics": ["question_title", "topic"], } - _labels_map = { - "ag_news0": 0, - "ag_news1": 1, - "ag_news2": 2, - "ag_news3": 3, - "yelp_review_full0": 4, - "yelp_review_full1": 5, - "yelp_review_full2": 6, - "yelp_review_full3": 7, - "yelp_review_full4": 8, - "dbpedia_140": 9, - "dbpedia_141": 10, - "dbpedia_142": 11, - "dbpedia_143": 12, - "dbpedia_144": 13, - "dbpedia_145": 14, - "dbpedia_146": 15, - "dbpedia_147": 16, - "dbpedia_148": 17, - "dbpedia_149": 18, - "dbpedia_1410": 19, - "dbpedia_1411": 20, - "dbpedia_1412": 21, - "dbpedia_1413": 22, - "yahoo_answers_topics0": 23, - "yahoo_answers_topics1": 24, - "yahoo_answers_topics2": 25, - "yahoo_answers_topics3": 26, - "yahoo_answers_topics4": 27, - "yahoo_answers_topics5": 28, - "yahoo_answers_topics6": 29, - "yahoo_answers_topics7": 30, - "yahoo_answers_topics8": 31, - "yahoo_answers_topics9": 32, + _label_offset = { + "ag_news": 0, + "yelp_review_full": 4, + "dbpedia_14": 9, + "yahoo_answers_topics": 23, } - domains = _multi_dataset_info.keys() + domains = list(_multi_dataset_info) def __init__( self, @@ -217,8 +187,8 @@ def __init__( tokenizer: transformers.PreTrainedTokenizer, data_id: str, tokenizer_kwargs: Optional[Dict[str, Any]] = None, - train_size: int = defaults.SMALL_TRAIN_SET_SIZE, - test_size: int = defaults.SMALL_TEST_SET_SIZE, + train_size: int = 115000, + test_size: int = 7600, val_size: float = defaults.VALIDATION_SIZE, seed: int = defaults.SEED, ): @@ -229,7 +199,7 @@ def __init__( self._train_size = train_size if test_size > 7600: - raise ValueError("The `test_size` must be smaller than 7600") + raise ValueError("The `test_size` must be smaller than or equal to 7600") self._test_size = test_size self._tokenizer = tokenizer @@ -244,26 +214,25 @@ def __init__( def prepare_data(self) -> None: """Download dataset.""" - - for split in ["train", "test"] + (["validation"] if self._val_size > 0 else []): + for split in ["train", "test"]: load_dataset(self.data_id, split=split, cache_dir=self._data_path) def setup(self) -> None: """Set up train, test and val datasets.""" - rnd_gen = torch.Generator().manual_seed(self._seed) def preprocess(example, text_field_name, label_field_name): return { 
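                # Tokenize the text and shift the per-dataset label by a fixed offset so that
                # the four domains share one 33-class output space (ag_news at 0, yelp at 4,
                # dbpedia at 9, yahoo at 23).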
**self._tokenizer(example[text_field_name], **self._tokenizer_kwargs), - "label": self._labels_map[f"{self.data_id}{example[label_field_name]}"], + "label": example[label_field_name] + + MultiTextDataModule._label_offset[self.data_id], } def get_split(split_name): dataset = load_dataset(self.data_id, split=split_name, cache_dir=self._data_path) # the following is hack needed because the output space of the new dataset is # the union of the output spaces of the single datasets - # HF datasets check for the max label id and we need to make sure we update that + # HF datasets check for the max label id, and we need to make sure we update that # without this change the setup will fail with a value error (label id > max labels) new_features = dataset.features.copy() new_features[self._multi_dataset_info[self.data_id][1]] = datasets.ClassLabel( @@ -278,7 +247,6 @@ def get_split(split_name): set_size = self._test_size rnd_idx = torch.randint( - low=0, high=len(dataset), size=(set_size,), generator=rnd_gen, @@ -300,6 +268,5 @@ def get_split(split_name): return _InputTargetWrapper(dataset) self._train_data = get_split("train") + self._train_data, self._val_data = self._split_train_val_data(self._train_data) self._test_data = get_split("test") - if self._val_size > 0: - self._val_data = get_split("validation") diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 665d5a5b..6069574e 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -12,7 +12,7 @@ from transformers import AutoTokenizer from wild_time_data import default_transform -from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule from renate.benchmark.datasets.vision_datasets import ( CLEARDataModule, DomainNetDataModule, @@ -118,6 +118,9 @@ def get_data_module( input_column: Optional[str], target_column: Optional[str], ) -> RenateDataModule: + tokenizer = None + if pretrained_model_name is not None: + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) if dataset_name in TorchVisionDataModule.dataset_dict: return TorchVisionDataModule( data_path, dataset_name=dataset_name, val_size=val_size, seed=seed @@ -134,7 +137,7 @@ def get_data_module( "seed": seed, } if pretrained_model_name is not None: - data_module_kwargs["tokenizer"] = AutoTokenizer.from_pretrained(pretrained_model_name) + data_module_kwargs["tokenizer"] = tokenizer return WildTimeDataModule(**data_module_kwargs) if dataset_name == "DomainNet": return DomainNetDataModule( @@ -144,8 +147,15 @@ def get_data_module( val_size=val_size, seed=seed, ) + if dataset_name == "MultiText": + return MultiTextDataModule( + data_path=data_path, + tokenizer=tokenizer, + data_id="ag_news", + val_size=val_size, + seed=seed, + ) if dataset_name.startswith("hfd-"): - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) return HuggingFaceTextDataModule( data_path=data_path, dataset_name=dataset_name[4:], @@ -342,6 +352,7 @@ def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Opti if dataset_name in [ "MNIST", "FashionMNIST", + "MultiText", ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): return None if dataset_name in ["CIFAR10", "CIFAR100"]: @@ -390,6 +401,7 @@ def test_transform( if dataset_name in [ "MNIST", "FashionMNIST", + "MultiText", ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): 
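        # Text datasets (MultiText, hfd-*), MNIST-style tensors and the remaining wild-time
        # datasets are consumed as-is, so no torchvision transform is returned for them.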
return None if dataset_name in ["CIFAR10", "CIFAR100"]: diff --git a/src/renate/defaults.py b/src/renate/defaults.py index c41d7499..9a53861d 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -106,8 +106,6 @@ # Benchmark datasets/models TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True} -SMALL_TRAIN_SET_SIZE = 1000 -SMALL_TEST_SET_SIZE = 1000 def scheduler(config_space: Dict[str, Any], mode: str, metric: str): diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index af5cf9d1..062420dc 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -8,7 +8,7 @@ from torchvision.transforms import Compose, Normalize, ToTensor from renate.benchmark import experiment_config -from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule +from renate.benchmark.datasets.nlp_datasets import HuggingFaceTextDataModule, MultiTextDataModule from renate.benchmark.datasets.vision_datasets import CLEARDataModule, TorchVisionDataModule from renate.benchmark.experiment_config import ( data_module_fn, @@ -93,6 +93,7 @@ def test_model_fn_fails_for_unknown_model(): "text", "label", ), + ("MultiText", MultiTextDataModule, "distilbert-base-uncased", None, None), ), ) def test_get_data_module( From c4bcba3fbc03609e24af33262a3a00a1645e4d85 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Tue, 15 Aug 2023 14:24:46 +0200 Subject: [PATCH 72/89] Masking of logits of irrelevant classes (#364) --- src/renate/benchmark/experimentation.py | 4 ++ src/renate/cli/parsing_functions.py | 11 +++- src/renate/defaults.py | 1 + src/renate/training/training.py | 4 ++ .../updaters/avalanche/model_updater.py | 16 ++++++ src/renate/updaters/experimental/er.py | 18 +++++++ .../updaters/experimental/fine_tuning.py | 2 + src/renate/updaters/experimental/gdumb.py | 2 + src/renate/updaters/experimental/joint.py | 2 + .../updaters/experimental/offline_er.py | 7 +++ .../updaters/experimental/repeated_distill.py | 6 +++ src/renate/updaters/learner.py | 36 ++++++++++++- src/renate/updaters/model_updater.py | 2 + src/renate/utils/misc.py | 27 +++++++++- src/renate/utils/pytorch.py | 33 +++++++++++- test/renate/utils/test_misc.py | 54 ++++++++++++++++++- test/renate/utils/test_pytorch.py | 50 ++++++++++++++++- 17 files changed, 269 insertions(+), 6 deletions(-) diff --git a/src/renate/benchmark/experimentation.py b/src/renate/benchmark/experimentation.py index 31722498..bca37d51 100644 --- a/src/renate/benchmark/experimentation.py +++ b/src/renate/benchmark/experimentation.py @@ -183,6 +183,10 @@ def execute_experiment_job( devices: Number of devices to use. deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. In this function this parameter is set to True by default. + gradient_clip_val: The value at which to clip gradients. Passing None disables it. + `More details `__ + gradient_clip_algorithm: The gradient clipping algorithm to use. Can be norm or value. + `More details `__ job_name: Name of the experiment job. strategy: Name of the distributed training strategy to use. 
`More details `__ diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index fff36f49..4dae8f78 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -48,7 +48,7 @@ def get_updater_and_learner_kwargs( """Returns the model updater class and the keyword arguments for the learner.""" if args.updater.startswith("Avalanche-") and find_spec("avalanche", None) is None: raise ImportError("Avalanche is not installed. Please run `pip install Renate[avalanche]`.") - learner_args = ["batch_size", "seed"] + learner_args = ["batch_size", "seed", "mask_unused_classes"] base_er_args = learner_args + [ "loss_weight", "ema_memory_update_gamma", @@ -324,6 +324,15 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: "choices": ["norm", "value", None], "argument_group": OPTIONAL_ARGS_GROUP, }, + "mask_unused_classes": { + "default": str(defaults.MASK_UNUSED_CLASSES), + "type": str, + "choices": ["True", "False"], + "help": "Whether to use a class mask to kill the unused logits. Useful possibly for " + "class incremental learning methods. ", + "argument_group": OPTIONAL_ARGS_GROUP, + "true_type": bool, + }, "prepare_data": { "type": str, "default": "True", diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 9a53861d..9025140d 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -45,6 +45,7 @@ FRAMEWORK_VERSION = "1.13.1" TASK_ID = "default_task" +MASK_UNUSED_CLASSES = False WORKING_DIRECTORY = "renate_working_dir" LOGGER = TensorBoardLogger LOGGER_KWARGS = { diff --git a/src/renate/training/training.py b/src/renate/training/training.py index 89c0cbb8..ae993ee5 100644 --- a/src/renate/training/training.py +++ b/src/renate/training/training.py @@ -144,6 +144,10 @@ def run_training_job( precision: Type of bit precision to use. `More details `__ deterministic_trainer: When true the Trainer adopts a deterministic behaviour also on GPU. + gradient_clip_val: The value at which to clip gradients. Passing None disables it. + `More details `__ + gradient_clip_algorithm: The gradient clipping algorithm to use. Can be norm or value. + `More details `__ job_name: Prefix for the name of the SageMaker training job. """ assert ( diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index 68438f61..e5fb0933 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -1,6 +1,7 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import logging +import warnings from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Type @@ -47,6 +48,13 @@ class AvalancheModelUpdater(SingleTrainingLoopUpdater): _report = Reporter() + def __init__(self, *args, **kwargs): + if kwargs.get("mask_unused_classes", False) is True: + warnings.warn( + "Avalanche model updaters do not support mask_unused_classes. Ignoring it." 
+ ) + super().__init__(*args, **kwargs) + def _load_learner( self, learner_class: Type[Learner], @@ -276,6 +284,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "batch_size": batch_size, @@ -310,6 +319,7 @@ def __init__( precision=precision, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm, + mask_unused_classes=mask_unused_classes, ) @@ -344,6 +354,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "batch_size": batch_size, @@ -377,6 +388,7 @@ def __init__( precision=precision, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm, + mask_unused_classes=mask_unused_classes, ) @@ -412,6 +424,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "batch_size": batch_size, @@ -446,6 +459,7 @@ def __init__( precision=precision, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm, + mask_unused_classes=mask_unused_classes, ) @@ -480,6 +494,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -513,4 +528,5 @@ def __init__( precision=precision, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index f9e6da71..82b9a1da 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -29,6 +29,7 @@ ShrinkAndPerturbReinitializationComponent, ) from renate.updaters.model_updater import SingleTrainingLoopUpdater +from renate.utils.misc import maybe_populate_mask_and_ignore_logits from renate.utils.pytorch import move_tensors_to_device @@ -140,6 +141,13 @@ def training_step( batch_memory = self._sample_from_buffer(device=step_output["loss"].device) (inputs_memory, _), metadata_memory = batch_memory outputs_memory = self(inputs_memory) + + outputs_memory, self._class_mask = maybe_populate_mask_and_ignore_logits( + self._mask_unused_classes, + self._class_mask, + self._classes_in_current_task, + outputs_memory, + ) intermediate_representation_memory = ( self._model.get_intermediate_representation() ) @@ -554,6 +562,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -594,6 +603,7 @@ def __init__( 
deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) @@ -635,6 +645,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -676,6 +687,7 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) @@ -718,6 +730,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -760,6 +773,7 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) @@ -805,6 +819,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -850,6 +865,7 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) @@ -901,6 +917,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -952,4 +969,5 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/fine_tuning.py b/src/renate/updaters/experimental/fine_tuning.py index f31295dd..c9b82c70 100644 --- a/src/renate/updaters/experimental/fine_tuning.py +++ b/src/renate/updaters/experimental/fine_tuning.py @@ -44,6 +44,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "batch_size": batch_size, @@ -77,4 +78,5 @@ def __init__( precision=precision, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index b7cee0b6..623046a8 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -132,6 +132,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = 
defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -168,4 +169,5 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/joint.py b/src/renate/updaters/experimental/joint.py index 22d38bd6..7d6ca1d1 100644 --- a/src/renate/updaters/experimental/joint.py +++ b/src/renate/updaters/experimental/joint.py @@ -121,6 +121,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "batch_size": batch_size, @@ -153,4 +154,5 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 17c3e0e1..78ac2dd0 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -17,6 +17,7 @@ from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater +from renate.utils.misc import maybe_populate_mask_and_ignore_logits from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors @@ -113,6 +114,10 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) inputs = cat_nested_tensors((inputs, inputs_mem), 0) targets = torch.cat((targets, targets_mem), 0) outputs = self(inputs) + + outputs, self._class_mask = maybe_populate_mask_and_ignore_logits( + self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs + ) loss = self._loss_fn(outputs, targets) if "memory" in batch: loss_current = loss[:batch_size_current].mean() @@ -169,6 +174,7 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = { "memory_size": memory_size, @@ -206,4 +212,5 @@ def __init__( deterministic_trainer=deterministic_trainer, gradient_clip_algorithm=gradient_clip_algorithm, gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) diff --git a/src/renate/updaters/experimental/repeated_distill.py b/src/renate/updaters/experimental/repeated_distill.py index dc3264f1..d0662688 100644 --- a/src/renate/updaters/experimental/repeated_distill.py +++ b/src/renate/updaters/experimental/repeated_distill.py @@ -123,6 +123,9 @@ def __init__( seed: Optional[int] = None, early_stopping_enabled=False, deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): learner_kwargs = {"memory_size": memory_size, "batch_size": batch_size, "seed": seed} 
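        # Gradient clipping and class-mask options are forwarded to the base updater below
        # rather than being stored in the learner kwargs.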
super().__init__( @@ -152,6 +155,9 @@ def __init__( early_stopping_enabled=early_stopping_enabled, logged_metrics=logged_metrics, deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, ) def update( diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index b84b27ea..fd9d4123 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -20,7 +20,8 @@ from renate.memory import DataBuffer, InfiniteBuffer, ReservoirBuffer from renate.models import RenateModule from renate.types import NestedTensors -from renate.utils.pytorch import get_generator +from renate.utils.misc import maybe_populate_mask_and_ignore_logits +from renate.utils.pytorch import get_generator, unique_classes class RenateLightningModule(LightningModule, abc.ABC): @@ -44,6 +45,8 @@ class RenateLightningModule(LightningModule, abc.ABC): batch_size: Training batch size. logged_metrics: Metrics logged additional to the default ones. seed: See :func:`renate.models.utils.get_generator`. + mask_unused_classes: Flag to use if logits corresponding to unused classes are to be ignored + in the loss computation. Possibly useful for class incremental learning. """ def __init__( @@ -56,6 +59,7 @@ def __init__( batch_size: int = defaults.BATCH_SIZE, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, seed: int = defaults.SEED, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ) -> None: super().__init__() self._model = model @@ -65,6 +69,11 @@ def __init__( self._learning_rate_scheduler_interval = learning_rate_scheduler_interval self._batch_size = batch_size self._seed = seed + self._mask_unused_classes = mask_unused_classes + + self._class_mask = None + self._classes_in_current_task = None + self._task_id: str = defaults.TASK_ID self._train_dataset: Optional[Dataset] = None self._val_dataset: Optional[Dataset] = None @@ -149,6 +158,10 @@ def on_model_update_start( self._val_collate_fn = val_dataset_collate_fn self._task_id = task_id self._model.add_task_params(task_id=self._task_id) + if self._mask_unused_classes: + # The first forward prop will populate the _class_mask with the following + # unique classes + self._classes_in_current_task = unique_classes(self._train_dataset) def train_dataloader(self) -> DataLoader: """Returns the dataloader for training the model.""" @@ -192,6 +205,9 @@ def training_step( """PyTorch Lightning function to return the training loss.""" inputs, targets = self.training_step_unpack_batch(batch) outputs = self(inputs) + outputs, self._class_mask = maybe_populate_mask_and_ignore_logits( + self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs + ) intermediate_representation = self._model.get_intermediate_representation() self._model.reset_intermediate_representation_cache() loss = self._loss_fn(outputs, targets).mean() @@ -327,6 +343,7 @@ def __init__( test_target_transform: Optional[Callable] = None, logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, seed: int = defaults.SEED, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ) -> None: super().__init__( model=model, @@ -337,6 +354,7 @@ def __init__( batch_size=batch_size, logged_metrics=logged_metrics, seed=seed, + mask_unused_classes=mask_unused_classes, ) self._train_transform = train_transform self._train_target_transform = train_target_transform @@ -482,3 +500,19 @@ def save(self, output_state_dir: str) -> None: 
def load(self, input_state_dir: str) -> None: super().load(input_state_dir) self._memory_buffer.load(os.path.join(input_state_dir, "memory_buffer")) + + def on_model_update_start( + self, + train_dataset: Dataset, + val_dataset: Dataset, + train_dataset_collate_fn: Optional[Callable] = None, + val_dataset_collate_fn: Optional[Callable] = None, + task_id: Optional[str] = None, + ) -> None: + super().on_model_update_start( + train_dataset, val_dataset, train_dataset_collate_fn, val_dataset_collate_fn, task_id + ) + if self._mask_unused_classes: + self._classes_in_current_task = self._classes_in_current_task.union( + unique_classes(self._memory_buffer) + ) diff --git a/src/renate/updaters/model_updater.py b/src/renate/updaters/model_updater.py index 6d35677e..066f9aef 100644 --- a/src/renate/updaters/model_updater.py +++ b/src/renate/updaters/model_updater.py @@ -276,10 +276,12 @@ def __init__( deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, ): self._learner_kwargs = learner_kwargs or {} self._learner_kwargs["loss_fn"] = loss_fn self._learner_kwargs["optimizer"] = optimizer + self._learner_kwargs["mask_unused_classes"] = mask_unused_classes if learning_rate_scheduler is not None: self._learner_kwargs["learning_rate_scheduler"] = learning_rate_scheduler self._learner_kwargs[ diff --git a/src/renate/utils/misc.py b/src/renate/utils/misc.py index 3e12ab5f..e2554453 100644 --- a/src/renate/utils/misc.py +++ b/src/renate/utils/misc.py @@ -1,6 +1,10 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from typing import Union +from typing import Optional, Set, Union + +import torch + +from renate.utils.pytorch import complementary_indices def int_or_str(x: str) -> Union[str, int]: @@ -12,3 +16,24 @@ def int_or_str(x: str) -> Union[str, int]: return int(x) except ValueError: return x + + +def maybe_populate_mask_and_ignore_logits( + use_masking: bool, + class_mask: Optional[torch.Tensor], + classes_in_current_task: Optional[Set[int]], + logits: torch.Tensor, +): + """Snippet to compute which logits to ignore after computing the class mask if required.""" + if use_masking: + if class_mask is None: + # Now is the time to repopulate the class_mask + class_mask = torch.tensor( + complementary_indices(logits.size(1), classes_in_current_task), + device=logits.device, + dtype=torch.long, + ) + # fill the logits with -inf + logits.index_fill_(1, class_mask.to(logits.device), -float("inf")) + + return logits, class_mask diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 8f78e037..30a5738d 100644 --- a/src/renate/utils/pytorch.py +++ b/src/renate/utils/pytorch.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging import math -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Set, Tuple, Union import torch from torch.utils.data import Dataset, random_split @@ -125,3 +125,34 @@ def cat_nested_tensors( key: cat_nested_tensors([nested_tensor[key] for nested_tensor in nested_tensors], axis) for key in nested_tensors[0] } + + +def unique_classes(dataset: torch.utils.data.Dataset) -> Set[int]: + """Compute the unique class ids in a dataset. + + Args: + dataset: Instance of Torch dataset. 
+ """ + from renate.memory.buffer import DataBuffer # to avoid circular import + + if isinstance(dataset, DataBuffer): + label_element = lambda elem: elem[0][1] + else: + label_element = lambda elem: elem[1] + + unique_values = set() + for ind in range(len(dataset)): + label = label_element(dataset[ind]) + unique_values.add(label.item()) + + return unique_values + + +def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int]: + """Compute the asymmetric difference between the two arguments + + Args: + num_outputs: An integer of total number of classes the model can output. + valid_classes: A set of integers of valid classes. + """ + return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes] diff --git a/test/renate/utils/test_misc.py b/test/renate/utils/test_misc.py index d02d248e..601c9524 100644 --- a/test/renate/utils/test_misc.py +++ b/test/renate/utils/test_misc.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 import pytest +import torch -from renate.utils.misc import int_or_str +from renate.utils.misc import int_or_str, maybe_populate_mask_and_ignore_logits @pytest.mark.parametrize( @@ -17,3 +18,54 @@ ) def test_int_or_str(data_type, target): assert int_or_str(data_type) == target + + +@pytest.mark.parametrize( + "use_masking, class_mask, classes_in_task, logits, correct_output", + [ + [False, None, None, 0.5 * torch.ones((1, 5)), (0.5 * torch.ones((1, 5)), None)], + [ + True, + None, + {0, 1, 2}, + 0.5 * torch.ones((1, 5)), + ( + torch.tensor([[0.5000, 0.5000, 0.5000, -float("inf"), -float("inf")]]), + torch.tensor([3, 4]), + ), + ], + [ + True, + torch.tensor([3, 4]), + {0, 1, 2}, + 0.5 * torch.ones((1, 5)), + ( + torch.tensor([[0.5000, 0.5000, 0.5000, -float("inf"), -float("inf")]]), + torch.tensor([3, 4]), + ), + ], + ], +) +def test_possibly_populate_mask_and_ignore_logits( + use_masking, class_mask, classes_in_task, logits, correct_output +): + out_logits, out_cm = maybe_populate_mask_and_ignore_logits( + use_masking=use_masking, + class_mask=class_mask, + classes_in_current_task=classes_in_task, + logits=logits, + ) + + # There are a few cases to test. Besides the obvious ones, we also check if the function is a + # no-op for some parameters. For eg: when a class_mask is provided, it should be returned as is + # All operations on logits should be inplace. Both are accomplished through comparing + # data_ptr(). + assert torch.equal(out_logits, correct_output[0]) + if out_cm is None: + assert out_cm == correct_output[1] + else: + assert torch.equal(out_cm, correct_output[1]) + + assert out_logits.data_ptr() == logits.data_ptr() + if class_mask is not None: + assert class_mask.data_ptr() == out_cm.data_ptr() diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py index b247976a..f96bcfcd 100644 --- a/test/renate/utils/test_pytorch.py +++ b/test/renate/utils/test_pytorch.py @@ -1,12 +1,22 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 +import numpy as np import pytest import torch import torchvision from torch.utils.data import TensorDataset +from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule +from renate.benchmark.scenarios import ClassIncrementalScenario +from renate.memory.buffer import ReservoirBuffer from renate.utils import pytorch -from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors, randomly_split_data +from renate.utils.pytorch import ( + cat_nested_tensors, + complementary_indices, + get_length_nested_tensors, + randomly_split_data, + unique_classes, +) @pytest.mark.parametrize("model", [torchvision.models.resnet18(pretrained=True)]) @@ -102,3 +112,41 @@ def test_cat_nested_tensors_wrong_shape(): cat_nested_tensors(((tensor1, tensor1), (tensor1, tensor2))) with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"): cat_nested_tensors(({"k1": tensor1, "k2": tensor1}, {"k1": tensor1, "k2": tensor2})) + + +@pytest.mark.parametrize( + "num_outputs, indices, expected_output", + [ + [5, {2, 4}, [0, 1, 3]], + [torch.rand(5, 5).size(1), {1, 2, 3}, [0, 4]], + [torch.rand(5, 5).shape[1], {1, 2, 3}, [0, 4]], + ], +) +def test_complementary_indices(num_outputs, indices, expected_output): + assert expected_output == complementary_indices(num_outputs, indices) + + +@pytest.mark.parametrize("test_dataset", [True, False]) +def test_unique_classes(tmpdir, test_dataset): + if test_dataset: + class_groupings = np.arange(0, 100).reshape(10, 10).tolist() + data_module = TorchVisionDataModule(tmpdir, dataset_name="CIFAR100", val_size=0.2) + data_module.prepare_data() + for chunk_id in range(len(class_groupings)): + scenario = ClassIncrementalScenario( + data_module=data_module, groupings=class_groupings, chunk_id=chunk_id + ) + scenario.setup() + ds = scenario.val_data() + predicted_unique = unique_classes(ds) + + assert predicted_unique == set(class_groupings[chunk_id]) + else: + X = torch.randn(10, 3) + y = torch.arange(0, 10) + ds = torch.utils.data.TensorDataset(X, y) + metadata = {"foo": torch.ones(10)} + buffer = ReservoirBuffer(X.shape[0]) + buffer.update(ds, metadata) + predicted_unique = unique_classes(buffer) + assert predicted_unique == set(list(range(10))) From 593cff3df06d008ee49572194136a3bf5c1e44a7 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Wed, 16 Aug 2023 17:34:14 +0200 Subject: [PATCH 73/89] Modifies current text transformer implementation to a RenateBenchmarkingModule (#380) --- requirements.txt | 2 +- src/renate/benchmark/models/transformer.py | 61 ++++++++++++------- .../benchmark/models/test_text_transformer.py | 28 +++++++++ 3 files changed, 69 insertions(+), 22 deletions(-) create mode 100644 test/renate/benchmark/models/test_text_transformer.py diff --git a/requirements.txt b/requirements.txt index a8121e0e..517ccc5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ torchmetrics>=0.11.0, <0.11.5 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 datasets>=2.9.0, <2.14.1 -transformers>=4.30.0, <4.30.2 +transformers>=4.31.0, <4.31.1 scipy>=1.9.0, <1.11.2 diff --git a/src/renate/benchmark/models/transformer.py b/src/renate/benchmark/models/transformer.py index 71ae33d3..d1909233 100644 --- a/src/renate/benchmark/models/transformer.py +++ b/src/renate/benchmark/models/transformer.py @@ -1,39 +1,58 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Optional +from typing import Optional -import torch -from torch import Tensor -from transformers import AutoModelForSequenceClassification +from transformers import AutoModelForTextEncoding, PreTrainedModel -from renate.models import RenateModule +from renate.benchmark.models.base import RenateBenchmarkingModule +from renate.models.prediction_strategies import PredictionStrategy -class HuggingFaceSequenceClassificationTransformer(RenateModule): - """RenateModule which wraps around Hugging Face transformers. +class FeatureExtractorTextTransformer(PreTrainedModel): + """This is a facade class to extract the correct output from the transformer model.""" + + def __init__(self, pretrained_model_name: str): + model = AutoModelForTextEncoding.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name + ) + super().__init__(model.config) + self._model = model + + def forward(self, x): + out = self._model(**x, return_dict=True) + if hasattr(out, "pooler_output"): + return out.pooler_output + else: + return out.last_hidden_state[:, 0] # 0th element is used for classification. + + +class HuggingFaceSequenceClassificationTransformer(RenateBenchmarkingModule): + """RenateBenchmarkingModule which wraps around Hugging Face transformers. Args: pretrained_model_name: Hugging Face model id. num_outputs: Number of outputs. + prediction_strategy: Continual learning strategies may alter the prediction at train or test + time. + add_icarl_class_means: If ``True``, additional parameters used only by the + ``ICaRLModelUpdater`` are added. Only required when using that updater. """ def __init__( self, pretrained_model_name: str, - num_outputs: int, - ) -> None: + num_outputs: int = 10, + prediction_strategy: Optional[PredictionStrategy] = None, + add_icarl_class_means: bool = True, + ): + model = FeatureExtractorTextTransformer(pretrained_model_name=pretrained_model_name) + constructor_args = dict(pretrained_model_name=pretrained_model_name) super().__init__( - constructor_arguments={ - "pretrained_model_name": pretrained_model_name, - "num_outputs": num_outputs, - }, - ) - self._model = AutoModelForSequenceClassification.from_pretrained( - pretrained_model_name, num_labels=num_outputs, return_dict=False + embedding_size=model.config.hidden_size, + num_outputs=num_outputs, + constructor_arguments=constructor_args, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, ) - def forward(self, x: Dict[str, Tensor], task_id: Optional[str] = None) -> torch.Tensor: - return self._model(**x)[0] - - def _add_task_params(self, task_id: str) -> None: - assert not len(self._tasks_params_ids), "Transformer does not work for multiple tasks." + self._backbone = model diff --git a/test/renate/benchmark/models/test_text_transformer.py b/test/renate/benchmark/models/test_text_transformer.py new file mode 100644 index 00000000..d9426fe4 --- /dev/null +++ b/test/renate/benchmark/models/test_text_transformer.py @@ -0,0 +1,28 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
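# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. End-to-end use of
# the reworked wrapper with a Hugging Face tokenizer; the tokenizer call is the
# standard transformers API and the forward pass mirrors the tests below.
from transformers import AutoTokenizer

from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = HuggingFaceSequenceClassificationTransformer(
    pretrained_model_name="distilbert-base-uncased", num_outputs=10
)
inputs = tokenizer(
    ["a first example sentence", "a second one"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)
logits = model(dict(inputs))  # expected shape: (2, 10)
# ---------------------------------------------------------------------------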
+# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer + + +@pytest.mark.parametrize("model_name", ["distilbert-base-uncased", "bert-base-uncased"]) +def test_init(model_name): + HuggingFaceSequenceClassificationTransformer(pretrained_model_name=model_name, num_outputs=10) + + +@pytest.mark.parametrize( + "model_name,input_dim", + [ + ["distilbert-base-uncased", (128,)], + ["bert-base-uncased", (256,)], + ], +) +def test_text_transformer_fwd(model_name, input_dim): + transformer = HuggingFaceSequenceClassificationTransformer(pretrained_model_name=model_name) + + x = {"input_ids": torch.randint(0, 30000, (5, *input_dim))} + y_hat = transformer(x) + + assert y_hat.shape[0] == 5 + assert y_hat.shape[1] == 10 From 73df5c2bc6427bdc78e6d766e043b64d92c26cd9 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 17 Aug 2023 11:09:16 +0200 Subject: [PATCH 74/89] Replace memory batch size with a fraction of the total batch size (#359) --- .../updaters/offline-er-clear10.json | 4 +- .../updaters/offline-er-clear100.json | 4 +- .../updaters/offline-er-domainnet.json | 4 +- .../class_incremental_learning_cifar10_der.py | 4 +- examples/nlp_finetuning/start.py | 4 +- .../start_with_hpo.py | 2 +- .../start_without_hpo.py | 4 +- .../start_training_with_er_without_hpo.py | 4 +- .../start_training_without_hpo.py | 4 +- src/renate/cli/parsing_functions.py | 18 ++++---- src/renate/defaults.py | 1 + src/renate/updaters/avalanche/learner.py | 4 +- .../updaters/avalanche/model_updater.py | 4 +- src/renate/updaters/experimental/er.py | 20 ++++----- src/renate/updaters/experimental/gdumb.py | 4 +- .../updaters/experimental/offline_er.py | 11 +---- src/renate/updaters/learner.py | 17 ++++--- src/renate/utils/config_spaces.py | 4 +- test/conftest.py | 6 +-- .../updaters/avalanche-er-buffer500.json | 3 +- .../configs/updaters/cls-er-buffer500.json | 4 +- .../configs/updaters/der-buffer500.json | 4 +- .../configs/updaters/er-buffer500.json | 4 +- .../configs/updaters/gdumb-buffer500.json | 4 +- .../updaters/offline-er-buffer500.json | 4 +- .../configs/updaters/pod-er-buffer500.json | 4 +- .../configs/updaters/super-er-buffer500.json | 4 +- .../avalanche/test_avalanche_learner.py | 10 ++++- .../avalanche/test_avalanche_model_updater.py | 18 ++++---- test/renate/updaters/experimental/test_er.py | 44 ++++++++++--------- 30 files changed, 120 insertions(+), 106 deletions(-) diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear10.json b/benchmarks/experiment_configs/updaters/offline-er-clear10.json index 6fe89caf..b1714bff 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-clear10.json +++ b/benchmarks/experiment_configs/updaters/offline-er-clear10.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 128, - "memory_batch_size": 128, + "batch_size": 256, + "batch_memory_frac": 0.5, "memory_size": 3300 } diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear100.json b/benchmarks/experiment_configs/updaters/offline-er-clear100.json index a6d09005..e91e4fad 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-clear100.json +++ b/benchmarks/experiment_configs/updaters/offline-er-clear100.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 128, - "memory_batch_size": 128, + "batch_size": 256, + "batch_memory_frac": 0.5, "memory_size": 10000 } diff --git a/benchmarks/experiment_configs/updaters/offline-er-domainnet.json 
b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json index b336094f..d403b96c 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-domainnet.json +++ b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 3450 } diff --git a/examples/benchmarking/class_incremental_learning_cifar10_der.py b/examples/benchmarking/class_incremental_learning_cifar10_der.py index b9b8a67e..8b5c67e4 100644 --- a/examples/benchmarking/class_incremental_learning_cifar10_der.py +++ b/examples/benchmarking/class_incremental_learning_cifar10_der.py @@ -11,8 +11,8 @@ "learning_rate": 0.03, "alpha": 0.2, "beta": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 500, "max_epochs": 50, "loss_normalization": 0, diff --git a/examples/nlp_finetuning/start.py b/examples/nlp_finetuning/start.py index cb6a8951..932d9dfe 100644 --- a/examples/nlp_finetuning/start.py +++ b/examples/nlp_finetuning/start.py @@ -12,8 +12,8 @@ "weight_decay": 0.0, "learning_rate": 0.001, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 300, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/examples/simple_classifier_cifar10/start_with_hpo.py b/examples/simple_classifier_cifar10/start_with_hpo.py index 16fb9e2e..eac54342 100644 --- a/examples/simple_classifier_cifar10/start_with_hpo.py +++ b/examples/simple_classifier_cifar10/start_with_hpo.py @@ -13,7 +13,7 @@ "learning_rate": loguniform(1e-4, 1e-1), "alpha": uniform(0.0, 1.0), "batch_size": choice([32, 64, 128, 256]), - "memory_batch_size": 32, + "batch_memory_frac": 0.5, "memory_size": 1000, "loss_normalization": 0, "loss_weight": uniform(0.0, 1.0), diff --git a/examples/simple_classifier_cifar10/start_without_hpo.py b/examples/simple_classifier_cifar10/start_without_hpo.py index f7f408f3..177002e6 100644 --- a/examples/simple_classifier_cifar10/start_without_hpo.py +++ b/examples/simple_classifier_cifar10/start_without_hpo.py @@ -11,8 +11,8 @@ "weight_decay": 0.0, "learning_rate": 0.1, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 300, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/examples/train_mlp_locally/start_training_with_er_without_hpo.py b/examples/train_mlp_locally/start_training_with_er_without_hpo.py index 6031f107..b9e5f4f6 100644 --- a/examples/train_mlp_locally/start_training_with_er_without_hpo.py +++ b/examples/train_mlp_locally/start_training_with_er_without_hpo.py @@ -7,9 +7,9 @@ "momentum": 0.0, "weight_decay": 1e-2, "learning_rate": 0.05, - "batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "max_epochs": 50, - "memory_batch_size": 32, "memory_size": 500, } diff --git a/examples/train_mlp_locally/start_training_without_hpo.py b/examples/train_mlp_locally/start_training_without_hpo.py index f8d20dd3..e58c264f 100644 --- a/examples/train_mlp_locally/start_training_without_hpo.py +++ b/examples/train_mlp_locally/start_training_without_hpo.py @@ -9,8 +9,8 @@ "weight_decay": 0.0, "learning_rate": 0.1, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 500, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/src/renate/cli/parsing_functions.py 
b/src/renate/cli/parsing_functions.py index 4dae8f78..fe4f33cd 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -53,7 +53,7 @@ def get_updater_and_learner_kwargs( "loss_weight", "ema_memory_update_gamma", "memory_size", - "memory_batch_size", + "batch_memory_frac", "loss_normalization", ] updater_class = None @@ -93,7 +93,7 @@ def get_updater_and_learner_kwargs( ] updater_class = SuperExperienceReplayModelUpdater elif args.updater == "Offline-ER": - learner_args = learner_args + ["loss_weight_new_data", "memory_size", "memory_batch_size"] + learner_args = learner_args + ["loss_weight_new_data", "memory_size", "batch_memory_frac"] updater_class = OfflineExperienceReplayModelUpdater elif args.updater == "RD": learner_args = learner_args + ["memory_size"] @@ -108,7 +108,7 @@ def get_updater_and_learner_kwargs( learner_args = learner_args updater_class = FineTuningModelUpdater elif args.updater == "Avalanche-ER": - learner_args = learner_args + ["memory_size", "memory_batch_size"] + learner_args = learner_args + ["memory_size", "batch_memory_frac"] from renate.updaters.avalanche.model_updater import ExperienceReplayAvalancheModelUpdater updater_class = ExperienceReplayAvalancheModelUpdater @@ -123,7 +123,7 @@ def get_updater_and_learner_kwargs( updater_class = LearningWithoutForgettingModelUpdater elif args.updater == "Avalanche-iCaRL": - learner_args = learner_args + ["memory_size", "memory_batch_size"] + learner_args = learner_args + ["memory_size", "batch_memory_frac"] from renate.updaters.avalanche.model_updater import ICaRLModelUpdater updater_class = ICaRLModelUpdater @@ -428,11 +428,11 @@ def _add_replay_learner_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: "help": "Memory size available for the memory buffer. Default: " f"{defaults.MEMORY_SIZE}.", }, - "memory_batch_size": { - "type": int, - "default": defaults.BATCH_SIZE, - "help": "Batch size used during model update for the memory buffer. Default: " - f"{defaults.BATCH_SIZE}.", + "batch_memory_frac": { + "type": float, + "default": defaults.BATCH_MEMORY_FRAC, + "help": "Fraction of the batch populated with memory data. 
Default: " + f"{defaults.BATCH_MEMORY_FRAC}.", }, } ) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 9025140d..12ed2983 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -18,6 +18,7 @@ WEIGHT_DECAY = 0.0 MAX_EPOCHS = 50 BATCH_SIZE = 32 +BATCH_MEMORY_FRAC = 0.5 LOSS_WEIGHT = 1.0 SEED = 0 EMA_MEMORY_UPDATE_GAMMA = 1.0 diff --git a/src/renate/updaters/avalanche/learner.py b/src/renate/updaters/avalanche/learner.py index ff258cae..d2e04a34 100644 --- a/src/renate/updaters/avalanche/learner.py +++ b/src/renate/updaters/avalanche/learner.py @@ -38,7 +38,7 @@ def update_settings( avalanche_learner._criterion = self._loss_fn avalanche_learner.train_epochs = max_epochs avalanche_learner.train_mb_size = self._batch_size - avalanche_learner.eval_mb_size = self._batch_size + avalanche_learner.eval_mb_size = self._batch_size + getattr(self, "_memory_batch_size", 0) avalanche_learner.device = device avalanche_learner.eval_every = eval_every @@ -57,7 +57,7 @@ def _create_avalanche_learner( optimizer=optimizer, criterion=self._loss_fn, train_mb_size=self._batch_size, - eval_mb_size=self._batch_size, + eval_mb_size=self._batch_size + getattr(self, "_memory_batch_size", 0), train_epochs=train_epochs, plugins=plugins, evaluator=default_evaluator(), diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index e5fb0933..10e1d2f4 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -259,7 +259,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, @@ -289,7 +289,7 @@ def __init__( learner_kwargs = { "batch_size": batch_size, "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "seed": seed, } super().__init__( diff --git a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index 82b9a1da..e9fe4018 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -532,7 +532,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -566,7 +566,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -614,7 +614,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -649,7 +649,7 
@@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -698,7 +698,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -734,7 +734,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -784,7 +784,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -823,7 +823,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -876,7 +876,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -921,7 +921,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index 623046a8..03e12f6f 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -106,7 +106,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, learning_rate_scheduler: Optional[partial] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, @@ -136,7 +136,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, "seed": seed, } diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 78ac2dd0..7809af1f 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -29,18 +29,11 @@ class OfflineExperienceReplayLearner(ReplayLearner): terminated. 
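# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the patch. Replacing
# ``memory_batch_size`` with ``batch_memory_frac`` (here and in the hunks
# below) makes ``batch_size`` the total effective batch size, a fraction of
# which is drawn from the rehearsal memory. The split implemented in
# ``ReplayLearner.__init__`` further down amounts to:
batch_size, batch_memory_frac, memory_size = 50, 0.4, 30

memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size))  # 20 points from memory
new_data_batch_size = batch_size - memory_batch_size  # 30 points from the new task data
assert (memory_batch_size, new_data_batch_size) == (20, 30)
# ---------------------------------------------------------------------------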
Args: - memory_size: The maximum size of the memory. - memory_batch_size: Size of batches sampled from the memory. The memory batch will be - appended to the batch sampled from the current dataset, leading to an effective batch - size of `memory_batch_size + batch_size`. loss_weight_new_data: The training loss will be a convex combination of the loss on the new data and the loss on the memory data. If a float (needs to be in [0, 1]) is given here, it will be used as the weight for the new data. If `None`, the weight will be set dynamically to `N_t / sum([N_1, ..., N_t])`, where `N_i` denotes the size of task/chunk `i` and the current task is `t`. - buffer_transform: The transformation to be applied to the memory buffer data samples. - buffer_target_transform: The target transformation to be applied to the memory buffer target - samples. """ def __init__(self, loss_weight_new_data: Optional[float] = None, **kwargs) -> None: @@ -147,7 +140,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight_new_data: Optional[float] = None, learning_rate_scheduler: Optional[partial] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 @@ -178,7 +171,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight_new_data": loss_weight_new_data, "batch_size": batch_size, "seed": seed, diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index fd9d4123..d7e8777f 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -456,9 +456,7 @@ class ReplayLearner(Learner, abc.ABC): Args: memory_size: The maximum size of the memory. - memory_batch_size: Size of batches sampled from the memory. The memory batch will be - appended to the batch sampled from the current dataset, leading to an effective batch - size of `memory_batch_size + batch_size`. + batch_memory_frac: Fraction of the batch that is sampled from rehearsal memory. buffer_transform: The transformation to be applied to the memory buffer data samples. buffer_target_transform: The target transformation to be applied to the memory buffer target samples. @@ -468,14 +466,21 @@ class ReplayLearner(Learner, abc.ABC): def __init__( self, memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, buffer_transform: Optional[Callable] = None, buffer_target_transform: Optional[Callable] = None, seed: int = defaults.SEED, **kwargs, ) -> None: - super().__init__(seed=seed, **kwargs) - self._memory_batch_size = min(memory_size, memory_batch_size) + if not (0 <= batch_memory_frac <= 1): + raise ValueError( + f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}." 
+ ) + memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size)) + batch_size = batch_size - memory_batch_size + super().__init__(batch_size=batch_size, seed=seed, **kwargs) + self._memory_batch_size = memory_batch_size self._memory_buffer = ReservoirBuffer( max_size=memory_size, seed=seed, diff --git a/src/renate/utils/config_spaces.py b/src/renate/utils/config_spaces.py index 2c1f6bc6..d0842049 100644 --- a/src/renate/utils/config_spaces.py +++ b/src/renate/utils/config_spaces.py @@ -15,13 +15,13 @@ def _get_range(start, end, step): "momentum": choice([0.0, 0.9, 0.99]), "weight_decay": loguniform(1e-6, 1e-2), "learning_rate": loguniform(0.001, 0.5), - "batch_size": 32, + "batch_size": 64, "max_epochs": 50, } _replay_config_space = { **_learner_config_space, **{ - "memory_batch_size": 32, + "batch_memory_frac": 0.5, "memory_size": 1000, }, } diff --git a/test/conftest.py b/test/conftest.py index 921f6540..9d1fd5ce 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -71,7 +71,7 @@ def pytest_collection_modifyitems(config, items): LEARNER_KWARGS = { ExperienceReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, }, @@ -81,7 +81,7 @@ def pytest_collection_modifyitems(config, items): RepeatedDistillationLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, OfflineExperienceReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "loss_weight_new_data": 0.5, "batch_size": 50, "seed": 1, @@ -90,7 +90,7 @@ def pytest_collection_modifyitems(config, items): AVALANCHE_LEARNER_KWARGS = { AvalancheReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, }, diff --git a/test/integration_tests/configs/updaters/avalanche-er-buffer500.json b/test/integration_tests/configs/updaters/avalanche-er-buffer500.json index a0cc0afe..d50ca947 100644 --- a/test/integration_tests/configs/updaters/avalanche-er-buffer500.json +++ b/test/integration_tests/configs/updaters/avalanche-er-buffer500.json @@ -4,6 +4,7 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, + "batch_size": 288, + "batch_memory_frac": 0.112, "memory_size": 500 } diff --git a/test/integration_tests/configs/updaters/cls-er-buffer500.json b/test/integration_tests/configs/updaters/cls-er-buffer500.json index a62b367a..a54b08b4 100644 --- a/test/integration_tests/configs/updaters/cls-er-buffer500.json +++ b/test/integration_tests/configs/updaters/cls-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.5, "beta": 0.1, diff --git a/test/integration_tests/configs/updaters/der-buffer500.json b/test/integration_tests/configs/updaters/der-buffer500.json index 13dea96b..6bf12918 100644 --- a/test/integration_tests/configs/updaters/der-buffer500.json +++ b/test/integration_tests/configs/updaters/der-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.2, "beta": 0.5, diff --git a/test/integration_tests/configs/updaters/er-buffer500.json b/test/integration_tests/configs/updaters/er-buffer500.json index d23103ee..b0645044 100644 --- a/test/integration_tests/configs/updaters/er-buffer500.json +++ 
b/test/integration_tests/configs/updaters/er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.5, "loss_normalization": 0, diff --git a/test/integration_tests/configs/updaters/gdumb-buffer500.json b/test/integration_tests/configs/updaters/gdumb-buffer500.json index efa07b0c..9d0d3ef0 100644 --- a/test/integration_tests/configs/updaters/gdumb-buffer500.json +++ b/test/integration_tests/configs/updaters/gdumb-buffer500.json @@ -4,7 +4,7 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500 } diff --git a/test/integration_tests/configs/updaters/offline-er-buffer500.json b/test/integration_tests/configs/updaters/offline-er-buffer500.json index eaa22099..f080a295 100644 --- a/test/integration_tests/configs/updaters/offline-er-buffer500.json +++ b/test/integration_tests/configs/updaters/offline-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "loss_weight_new_data": 0.5 } diff --git a/test/integration_tests/configs/updaters/pod-er-buffer500.json b/test/integration_tests/configs/updaters/pod-er-buffer500.json index c46f3f15..76ae14af 100644 --- a/test/integration_tests/configs/updaters/pod-er-buffer500.json +++ b/test/integration_tests/configs/updaters/pod-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 1.0, "distillation_type": "spatial", diff --git a/test/integration_tests/configs/updaters/super-er-buffer500.json b/test/integration_tests/configs/updaters/super-er-buffer500.json index d3dafffc..371d92ee 100644 --- a/test/integration_tests/configs/updaters/super-er-buffer500.json +++ b/test/integration_tests/configs/updaters/super-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "der_alpha": 1.0, "der_beta": 1.0, diff --git a/test/renate/updaters/avalanche/test_avalanche_learner.py b/test/renate/updaters/avalanche/test_avalanche_learner.py index 6e0e3eda..e5657e2c 100644 --- a/test/renate/updaters/avalanche/test_avalanche_learner.py +++ b/test/renate/updaters/avalanche/test_avalanche_learner.py @@ -30,6 +30,7 @@ def check_learner_settings( expected_max_epochs, expected_device, expected_eval_every, + expected_batch_size, expected_loss_fn=None, ): if isinstance(learner, AvalancheICaRLLearner): @@ -47,7 +48,7 @@ def check_learner_settings( assert avalanche_learner._criterion == expected_loss_fn assert avalanche_learner.optimizer == expected_optimizer assert avalanche_learner.train_epochs == expected_max_epochs - assert avalanche_learner.train_mb_size == learner_kwargs["batch_size"] + assert avalanche_learner.train_mb_size == expected_batch_size assert avalanche_learner.eval_mb_size == learner_kwargs["batch_size"] assert avalanche_learner.device == expected_device @@ -84,6 +85,11 @@ def test_update_settings(learner_class): expected_optimizer = SGD(expected_model.parameters(), lr=0.1) expected_device = 
torch.device("cpu") expected_eval_every = -1 + expected_batch_size = learner_kwargs["batch_size"] + if "batch_memory_frac" in learner_kwargs: + expected_batch_size = expected_batch_size - int( + learner_kwargs["batch_memory_frac"] * expected_batch_size + ) learner = learner_class( model=expected_model, optimizer=None, @@ -107,6 +113,7 @@ def test_update_settings(learner_class): expected_max_epochs=expected_max_epochs, expected_device=expected_device, expected_eval_every=expected_eval_every, + expected_batch_size=expected_batch_size, ) # Update @@ -135,4 +142,5 @@ def test_update_settings(learner_class): expected_max_epochs=expected_max_epochs, expected_device=expected_device, expected_eval_every=expected_eval_every, + expected_batch_size=expected_batch_size, ) diff --git a/test/renate/updaters/avalanche/test_avalanche_model_updater.py b/test/renate/updaters/avalanche/test_avalanche_model_updater.py index 2738c733..af2dad9c 100644 --- a/test/renate/updaters/avalanche/test_avalanche_model_updater.py +++ b/test/renate/updaters/avalanche/test_avalanche_model_updater.py @@ -77,15 +77,17 @@ def test_continuation_of_training_with_avalanche_model_updater(tmpdir, learner_c @pytest.mark.parametrize( - "batch_size,memory_size,memory_batch_size", - [[10, 10, 10], [20, 10, 10], [10, 100, 10], [10, 30, 1], [100, 10, 3]], + "batch_size,memory_size,batch_memory_frac", + [[20, 10, 0.5], [30, 10, 0.34], [20, 100, 0.5], [10, 30, 0.1], [100, 10, 0.03]], ) -def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_batch_size): +def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, batch_memory_frac): + expected_memory_batch_size = int(batch_memory_frac * batch_size) + expected_batch_size = batch_size - expected_memory_batch_size dataset_size = 100 model, dataset = get_model_and_dataset(dataset_size) learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, } model_updater = ExperienceReplayAvalancheModelUpdater( @@ -99,9 +101,9 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b ) model_updater.update(train_dataset=dataset) replay_plugin = plugin_by_class(ReplayPlugin, model_updater._learner.plugins) - assert replay_plugin.batch_size == batch_size + assert replay_plugin.batch_size == expected_batch_size assert replay_plugin.mem_size == memory_size - assert replay_plugin.batch_size_mem == memory_batch_size + assert replay_plugin.batch_size_mem == expected_memory_batch_size assert len(replay_plugin.storage_policy.buffer) == min( memory_size, dataset_size, len(replay_plugin.storage_policy.buffer) ) @@ -121,9 +123,9 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b ) replay_plugin = plugin_by_class(ReplayPlugin, model_updater._learner.plugins) - assert replay_plugin.batch_size == batch_size + assert replay_plugin.batch_size == expected_batch_size assert replay_plugin.mem_size == memory_size - assert replay_plugin.batch_size_mem == memory_batch_size + assert replay_plugin.batch_size_mem == expected_memory_batch_size model_updater.update(train_dataset=dataset) assert len(model_updater._learner.dataloader.data) == dataset_size assert len(model_updater._learner.dataloader.memory) == min( diff --git a/test/renate/updaters/experimental/test_er.py b/test/renate/updaters/experimental/test_er.py index 72545b72..de5d6ad9 100644 --- a/test/renate/updaters/experimental/test_er.py +++ 
b/test/renate/updaters/experimental/test_er.py @@ -26,14 +26,15 @@ def get_model_and_dataset(): @pytest.mark.parametrize( - "batch_size,memory_size,memory_batch_size", - [[10, 10, 10], [20, 10, 10], [10, 100, 10], [10, 30, 1], [100, 10, 3]], + "batch_size,memory_size,batch_memory_frac", + [[20, 10, 0.5], [30, 10, 0.34], [20, 100, 0.5], [10, 30, 0.1], [100, 10, 0.03]], ) -def test_er_overall_memory_size_after_update(batch_size, memory_size, memory_batch_size): +def test_er_overall_memory_size_after_update(batch_size, memory_size, batch_memory_frac): + memory_batch_size = int(batch_memory_frac * batch_size) model, dataset = get_model_and_dataset() learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, } model_updater = pytest.helpers.get_simple_updater( @@ -88,9 +89,15 @@ def test_er_validation_buffer(tmpdir): ) +def validate_common_args(model_updater, learner_kwargs): + memory_batch_size = int(learner_kwargs["batch_memory_frac"] * learner_kwargs["batch_size"]) + batch_size = learner_kwargs["batch_size"] - memory_batch_size + assert model_updater._learner._batch_size == batch_size + assert model_updater._learner._memory_batch_size == memory_batch_size + + def validate_cls_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["alpha"] assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["beta"] assert ( @@ -112,15 +119,13 @@ def validate_cls_er(model_updater, learner_kwargs): def validate_dark_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["beta"] assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["alpha"] def validate_pod_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["pod_loss"].weight == learner_kwargs["alpha"] assert ( model_updater._learner._components["pod_loss"]._distillation_type @@ -130,8 +135,7 @@ def validate_pod_er(model_updater, learner_kwargs): def validate_super_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["der_beta"] assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["der_alpha"] assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["cls_alpha"] @@ -174,7 +178,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_cls_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -186,7 +190,7 @@ def 
validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 2.3, @@ -202,7 +206,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_dark_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -210,7 +214,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 2.3, @@ -222,7 +226,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_pod_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -231,7 +235,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 0.123, @@ -244,7 +248,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_super_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "der_alpha": 0.123, @@ -262,7 +266,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "der_alpha": 2.3, From 2ca8df277430981730b45ecf8a1cac40cc23472b Mon Sep 17 00:00:00 2001 From: Lukas Balles Date: Fri, 18 Aug 2023 09:58:25 +0200 Subject: [PATCH 75/89] Make offline ER us total batch size in first update (#381) --- src/renate/updaters/experimental/offline_er.py | 9 +++++++-- .../configs/suites/quick/offline-er.json | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 7809af1f..c541faea 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -71,9 +71,9 @@ def on_model_update_start( self._num_points_current_task = len(train_dataset) def train_dataloader(self) -> DataLoader: - train_loader = super().train_dataloader() - loaders = {"current_task": train_loader} + loaders = {} if len(self._memory_buffer) > self._memory_batch_size: + loaders["current_task"] = super().train_dataloader() loaders["memory"] = DataLoader( dataset=self._memory_buffer, batch_size=self._memory_batch_size, @@ -83,6 +83,11 @@ def train_dataloader(self) -> DataLoader: pin_memory=True, collate_fn=self._train_collate_fn, ) + else: + batch_size = self._batch_size + self._batch_size += self._memory_batch_size + loaders["current_task"] = super().train_dataloader() + self._batch_size = batch_size return CombinedLoader(loaders, mode="max_size_cycle") def on_model_update_end(self) -> None: diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 2c62856a..3732b00e 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.7319999933242798, 0.4699999988079071], [0.7515000104904175, 0.49300000071525574]], - "expected_accuracy_darwin": [[0.7300000190734863, 0.5350000262260437]] + "expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 
0.3725000023841858]], + "expected_accuracy_darwin": [[0.7315000295639038, 0.49000000953674316]] } From 8b4bc59bd4d505934232f977dee35ffd06aa1429 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Mon, 21 Aug 2023 17:42:29 +0200 Subject: [PATCH 76/89] Learning to Prompt (#367) --- .../experiment_configs/models/bert.json | 2 +- doc/benchmarking/renate_benchmarks.rst | 2 +- doc/getting_started/supported_algorithms.rst | 6 + src/renate/benchmark/experiment_config.py | 75 +++-- src/renate/benchmark/models/__init__.py | 2 + src/renate/benchmark/models/base.py | 2 +- src/renate/benchmark/models/l2p.py | 312 ++++++++++++++++++ src/renate/benchmark/models/transformer.py | 17 +- .../benchmark/models/vision_transformer.py | 40 ++- src/renate/cli/parsing_functions.py | 32 +- src/renate/defaults.py | 3 + src/renate/updaters/experimental/l2p.py | 289 ++++++++++++++++ test/conftest.py | 25 +- test/renate/benchmark/models/test_l2p.py | 64 ++++ .../benchmark/models/test_text_transformer.py | 8 +- .../models/test_vision_transformer.py | 12 +- .../benchmark/test_experimentation_config.py | 25 +- test/renate/updaters/test_learner.py | 27 +- 18 files changed, 867 insertions(+), 76 deletions(-) create mode 100644 src/renate/benchmark/models/l2p.py create mode 100644 src/renate/updaters/experimental/l2p.py create mode 100644 test/renate/benchmark/models/test_l2p.py diff --git a/benchmarks/experiment_configs/models/bert.json b/benchmarks/experiment_configs/models/bert.json index 55f3ba6e..11103171 100644 --- a/benchmarks/experiment_configs/models/bert.json +++ b/benchmarks/experiment_configs/models/bert.json @@ -1,4 +1,4 @@ { "model_name": "HuggingFaceTransformer", - "pretrained_model_name": "bert-base-uncased" + "pretrained_model_name_or_path": "bert-base-uncased" } diff --git a/doc/benchmarking/renate_benchmarks.rst b/doc/benchmarking/renate_benchmarks.rst index 0daf12ee..f914eb84 100644 --- a/doc/benchmarking/renate_benchmarks.rst +++ b/doc/benchmarking/renate_benchmarks.rst @@ -95,7 +95,7 @@ The full list of models and model names including a short description is provide - * ``num_outputs``: Output dimensionality, for classification the number of classes. * - `~renate.benchmark.models.transformer.HuggingFaceSequenceClassificationTransformer` - Wrapper around Hugging Face transformers. - - * ``pretrained_model_name``: Hugging Face `transformer ID `__. + - * ``pretrained_model_name_or_path``: Hugging Face `transformer ID `__. * ``num_outputs``: The number of classes. diff --git a/doc/getting_started/supported_algorithms.rst b/doc/getting_started/supported_algorithms.rst index 8fd3e197..02c0ce98 100644 --- a/doc/getting_started/supported_algorithms.rst +++ b/doc/getting_started/supported_algorithms.rst @@ -36,6 +36,12 @@ using Renate (e.g., using :py:func:`~renate.training.training.run_training_job`; * - ``"FineTuning"`` - :py:class:`Learner ` - A simple method which trains the current model on only the new data without any sort of mitigation for forgetting. Used as "lower bound" baseline in experiments. + * - ``"LearningToPrompt"`` + - :py:class:`LearningToPromptLearner ` + - A class that implements a Learning to Prompt method for ViTs. The methods trains only the input prompts that are sampled from a prompt pool in an input dependent fashion. 
+ * - ``"LearningToPromptReplay"`` + - :py:class:`LearningToPromptLearner ` + - A class that extends the Learning to Prompt method to use a memory replay method like "Offline-ER" * - ``"Avalanche-ER"`` - :py:class:`AvalancheReplayLearner ` - A wrapper which gives access to Experience Replay as implemented in the Avalanche library. This method is the equivalent to our Offline-ER. diff --git a/src/renate/benchmark/experiment_config.py b/src/renate/benchmark/experiment_config.py index 6069574e..16ba10b0 100644 --- a/src/renate/benchmark/experiment_config.py +++ b/src/renate/benchmark/experiment_config.py @@ -21,6 +21,7 @@ from renate.benchmark.datasets.wild_time_data import WildTimeDataModule from renate.benchmark.models import ( MultiLayerPerceptron, + LearningToPromptTransformer, ResNet18, ResNet18CIFAR, ResNet34, @@ -64,6 +65,7 @@ "VisionTransformerL32": VisionTransformerL32, "VisionTransformerH14": VisionTransformerH14, "HuggingFaceTransformer": HuggingFaceSequenceClassificationTransformer, + "LearningToPromptTransformer": LearningToPromptTransformer, } @@ -76,7 +78,7 @@ def model_fn( num_hidden_layers: Optional[int] = None, hidden_size: Optional[Tuple[int]] = None, dataset_name: Optional[str] = None, - pretrained_model_name: Optional[str] = None, + pretrained_model_name_or_path: Optional[str] = None, ) -> RenateModule: """Returns a model instance.""" if model_name not in models: @@ -98,7 +100,14 @@ def model_fn( elif model_name == "HuggingFaceTransformer": if updater == "Avalanche-iCaRL": raise ValueError("Transformers do not support iCaRL.") - model_kwargs["pretrained_model_name"] = pretrained_model_name + model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path + elif (updater is not None) and ("LearningToPrompt" in updater): + if not model_name.startswith("LearningToPrompt"): + raise ValueError( + "L2P model updaters are designed to work only with " + f"LearningToPromptTransformer, but model name specified is {model_name}." 
+ ) + model_kwargs["pretrained_model_name_or_path"] = pretrained_model_name_or_path if model_state_url is None: model = model_class(**model_kwargs) else: @@ -114,13 +123,13 @@ def get_data_module( dataset_name: str, val_size: float, seed: int, - pretrained_model_name: Optional[str], + pretrained_model_name_or_path: Optional[str], input_column: Optional[str], target_column: Optional[str], ) -> RenateDataModule: tokenizer = None - if pretrained_model_name is not None: - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) + if pretrained_model_name_or_path is not None and "vit" not in pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) if dataset_name in TorchVisionDataModule.dataset_dict: return TorchVisionDataModule( data_path, dataset_name=dataset_name, val_size=val_size, seed=seed @@ -136,7 +145,7 @@ def get_data_module( "val_size": val_size, "seed": seed, } - if pretrained_model_name is not None: + if pretrained_model_name_or_path is not None: data_module_kwargs["tokenizer"] = tokenizer return WildTimeDataModule(**data_module_kwargs) if dataset_name == "DomainNet": @@ -277,7 +286,7 @@ def data_module_fn( randomness: Optional[float] = None, src_bucket: Optional[str] = None, src_object_name: Optional[str] = None, - pretrained_model_name: Optional[str] = None, + pretrained_model_name_or_path: Optional[str] = None, input_column: Optional[str] = None, target_column: Optional[str] = None, data_ids: Optional[List[Union[int, str]]] = None, @@ -289,7 +298,7 @@ def data_module_fn( dataset_name=dataset_name, val_size=val_size, seed=seed, - pretrained_model_name=pretrained_model_name, + pretrained_model_name_or_path=pretrained_model_name_or_path, input_column=input_column, target_column=target_column, ) @@ -333,10 +342,8 @@ def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Opti if dataset_name == "fmow": return default_transform(dataset_name) if dataset_name == "yearbook": - if ( - model_name is not None - and model_name.startswith("VisionTransformer") - and model_name != "VisionTransformerCIFAR" + if (model_name is not None) and ( + ("Transformer" in model_name) and (model_name != "VisionTransformerCIFAR") ): return transforms.Compose( [ @@ -356,13 +363,25 @@ def train_transform(dataset_name: str, model_name: Optional[str] = None) -> Opti ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): return None if dataset_name in ["CIFAR10", "CIFAR100"]: - return transforms.Compose( - [ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - _get_normalize_transform(dataset_name), - ] - ) + if (model_name is not None) and ( + ("Transformer" in model_name) and (model_name != "VisionTransformerCIFAR") + ): + return transforms.Compose( + [ + transforms.RandomResizedCrop( + 224, scale=(0.05, 1.0), ratio=(3.0 / 4.0, 4.0 / 3.0) + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + else: + return transforms.Compose( + [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + _get_normalize_transform(dataset_name), + ] + ) if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ @@ -383,10 +402,8 @@ def test_transform( if dataset_name == "fmow": return default_transform(dataset_name) if dataset_name == "yearbook": - if ( - model_name is not None - and model_name.startswith("VisionTransformer") - and model_name != "VisionTransformerCIFAR" + if (model_name is not None) and ( + ("Transformer" in model_name) and (model_name != 
"VisionTransformerCIFAR") ): return transforms.Compose( [ @@ -405,7 +422,17 @@ def test_transform( ] + wild_time_data.list_datasets() or dataset_name.startswith("hfd-"): return None if dataset_name in ["CIFAR10", "CIFAR100"]: - return _get_normalize_transform(dataset_name) + if (model_name is not None) and ( + ("Transformer" in model_name) and (model_name != "VisionTransformerCIFAR") + ): + return transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + ] + ) + else: + return _get_normalize_transform(dataset_name) if dataset_name in ["CLEAR10", "CLEAR100", "DomainNet"]: return transforms.Compose( [ diff --git a/src/renate/benchmark/models/__init__.py b/src/renate/benchmark/models/__init__.py index 0c537a0d..edfdf6bb 100644 --- a/src/renate/benchmark/models/__init__.py +++ b/src/renate/benchmark/models/__init__.py @@ -9,6 +9,7 @@ ResNet50, ResNet50CIFAR, ) +from renate.benchmark.models.l2p import LearningToPromptTransformer from renate.benchmark.models.vision_transformer import ( VisionTransformerB16, VisionTransformerB32, @@ -26,6 +27,7 @@ "ResNet34CIFAR", "ResNet50", "ResNet50CIFAR", + "LearningToPromptTransformer", "VisionTransformerB16", "VisionTransformerB32", "VisionTransformerCIFAR", diff --git a/src/renate/benchmark/models/base.py b/src/renate/benchmark/models/base.py index 30a5424e..474f7196 100644 --- a/src/renate/benchmark/models/base.py +++ b/src/renate/benchmark/models/base.py @@ -87,7 +87,7 @@ def get_extra_state(self, encode=True) -> Any: Encode converts the state into a torch tensor so that Deepspeed serialization works. We don't encode any of the super() calls, but encode only the final dict. """ - extra_state = super().get_extra_state(encode=not encode) + extra_state = super().get_extra_state(encode=False) extra_state["prediction_strategy"] = self._prediction_strategy return convert_to_tensor(extra_state) if encode else extra_state diff --git a/src/renate/benchmark/models/l2p.py b/src/renate/benchmark/models/l2p.py new file mode 100644 index 00000000..cd0f8379 --- /dev/null +++ b/src/renate/benchmark/models/l2p.py @@ -0,0 +1,312 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import functools +import logging +from typing import Callable, Optional, Union + +import torch +import torch.nn as nn + +from renate import defaults +from renate.benchmark.models.base import RenateBenchmarkingModule +from renate.benchmark.models.transformer import HuggingFaceSequenceClassificationTransformer +from renate.benchmark.models.vision_transformer import VisionTransformer +from renate.models.prediction_strategies import PredictionStrategy + + +logger = logging.getLogger(__name__) + + +class PromptPool(nn.Module): + """Implements the prompt pool for L2P + + Args: + pool_size: Total size of the prompt pool. + pool_selection_size: Number of prompts to select from the pool. + prompt_size: Number of tokens each prompt is equivalent to. + prompt_key_dim: Dimensions of the prompt key used to compute similarity. It has to be same + to the dimensions of `x` in forward. + embedding_dim: Output dimension of the token/patch embedding layer + train_prompt_keys: Whether to train the prompt keys. Currently unused. 
+ similarity_fn: Similarity function between input features and prompt keys + per_batch_prompt: Flag to use the same prompts for all elements in the batch + """ + + def __init__( + self, + pool_size: int = 10, + pool_selection_size: int = 5, + prompt_size: int = 5, + prompt_key_dim: int = 768, + embedding_dim: int = 768, + train_prompt_keys: bool = True, + similarity_fn: Union[Callable, str] = "cosine", + per_batch_prompt: bool = True, + ): + super().__init__() + self._M = pool_size ## total pool size + self._N = pool_selection_size ## number of prompts selected per input + self._Lp = prompt_size ## each prompt is equal to how many tokens + self._d = embedding_dim ## + self._pd = prompt_key_dim + self._per_batch_prompt = per_batch_prompt + + self._parse_similarity_fn(similarity_fn) + self.train_prompt_keys = train_prompt_keys ## This is unused for now + self.prompt_pool = nn.Parameter(torch.empty((self._M, self._Lp, self._d)).uniform_(-1, 1)) + self.prompt_keys = nn.Parameter(torch.empty((self._M, self._pd)).uniform_(-1, 1)) + + self.key_hist = torch.zeros((self._M,), dtype=torch.float32) + + def _parse_similarity_fn(self, similarity_fn: Union[Callable, str]) -> None: + if callable(similarity_fn): + self.similarity_fn = similarity_fn + elif not isinstance(similarity_fn, str): + raise ValueError( + "similarity_fn has to be a callable or a string representing similarity metric. " + "But got {similarity_fn}" + ) + elif similarity_fn == "cosine": + normalization_fn = functools.partial(torch.nn.functional.normalize, p=2) + self.similarity_fn = lambda x, y: normalization_fn(x).matmul(normalization_fn(y).t()) + else: + raise ValueError( + f"Currently only cosine similarity is supported, but got {similarity_fn}" + ) + + def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTensor] = None): + """ + Args: + x: Image features extracted. It can be [CLS] token or something else of + dimension B x self.pd.. + manual_prompt_indices: Indices to manually select prompts from pool, instead of + selecting from + """ + if manual_prompt_indices is None: + similarity_matrix = self.similarity_fn(x, self.prompt_keys) + _, idx = torch.topk(similarity_matrix, k=self._N, dim=1) + if self._per_batch_prompt: + prompt_id, id_counts = torch.unique(idx, return_counts=True, sorted=True) + if prompt_id.shape[0] < self._M: + ## The logic for this is taken from the public l2p implementation. + temp_pid = torch.full((self._M,), idx.min(), device=prompt_id.device) + temp_pid[: prompt_id.shape[0]] = prompt_id + prompt_id = temp_pid + + temp_idc = torch.zeros((self._M,), device=id_counts.device) + temp_idc[: id_counts.shape[0]] = id_counts + id_counts = temp_idc + + _, major_idx = torch.topk(id_counts, k=self._N) + idx = prompt_id[major_idx].expand(x.shape[0], -1) # B, top_k + loss_value = similarity_matrix[:, idx].sum() / (x.shape[0] * x.shape[0]) + else: + idx = manual_prompt_indices # should be of size B, top_k + loss_value = torch.tensor(0.0, device=x.device) + + selected_prompts = self.prompt_pool[idx].flatten(1, 2) + return selected_prompts, loss_value + + +class LearningToPromptTransformer(RenateBenchmarkingModule): + """ + Implements the vision transformer with prompt pool described in + Wang, Zifeng, et al. "Learning to prompt for continual learning." Proceedings of the + IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. + + Args: + pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub + to use. 
If provided, it overrides other arguments about architecture. + image_size: Size of the input image. + patch_size: Size of the patches. + num_layers: Number of Encoder layers. + num_heads: Number of Attention heads. + hidden_dim: Size of the Encoder's hidden state. + mlp_dim: Size of the intermediate Multi-layer Perceptron in the Encoder. + dropout: Dropout probability. + attention_dropout: Dropout probability for the attention in the Multi-head Attention layer. + num_outputs: Size of the output. + prediction_strategy: Continual learning strategies may alter the prediction at train or test + time. + add_icarl_class_means: If ``True``, additional parameters used only by the + ``ICaRLModelUpdater`` are added. Only required when using that updater. + pool_size: Total size of the prompt pool. + pool_selection_size: Number of prompts to select from the pool. + prompt_size: Number of tokens each prompt is equivalent to. + prompt_key_dim: Dimensions of the prompt key used to compute similarity. It has to be same + to the dimensions of `x` in forward. + train_prompt_keys: Whether to train the prompt keys. Currently unused. + similarity_fn: Similarity function between input features and prompt keys. + per_batch_prompt: Flag to use the same prompts for all elements in the batch. + prompt_embedding_features: Image feature type used to compute the similarity to prompt keys. + patch_pooler: Features to feed the classifier. + """ + + def __init__( + self, + pretrained_model_name_or_path="google/vit-base-patch16-224", + image_size: int = 32, + patch_size: int = 4, + num_layers: int = 12, + num_heads: int = 12, + hidden_dim: int = 768, + mlp_dim: int = 3072, + dropout: float = 0.1, + attention_dropout: float = 0.1, + num_outputs: int = 10, + prediction_strategy: Optional[PredictionStrategy] = None, + add_icarl_class_means: bool = True, + pool_size: int = 10, + pool_selection_size: int = 5, + prompt_size: int = 5, + prompt_key_dim: int = 768, + train_prompt_keys: bool = True, + similarity_fn: Union[Callable, str] = "cosine", + per_batch_prompt: bool = True, + prompt_embedding_features: str = "cls", + patch_pooler: str = "prompt_mean", + ) -> None: + if "vit" in pretrained_model_name_or_path: + transformer = VisionTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + num_outputs=num_outputs, + ) + self._is_text_transformer = False + else: + transformer = HuggingFaceSequenceClassificationTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + num_outputs=num_outputs, + ) + + self._is_text_transformer = True + transformer._tasks_params.clear() + prompter = PromptPool( + embedding_dim=transformer._embedding_size, + pool_size=pool_size, + pool_selection_size=pool_selection_size, + prompt_size=prompt_size, + prompt_key_dim=prompt_key_dim, + train_prompt_keys=train_prompt_keys, + similarity_fn=similarity_fn, + per_batch_prompt=per_batch_prompt, + ) + + super().__init__( + embedding_size=transformer._embedding_size, + num_outputs=num_outputs, + constructor_arguments=dict( + **transformer._constructor_arguments, + pool_size=pool_size, + pool_selection_size=pool_selection_size, + 
prompt_size=prompt_size, + prompt_key_dim=prompt_key_dim, + train_prompt_keys=train_prompt_keys, + similarity_fn=similarity_fn, + per_batch_prompt=per_batch_prompt, + ), + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + ) + + self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter}) + self.prompt_embedding_features = prompt_embedding_features + self.patch_pooler = patch_pooler + self.similarity_score: Optional[torch.Tensor] = None + + assert self.prompt_embedding_features in [ + "cls", + "mean", + ], f"Invalid method to extract prompt embedding features. Got {prompt_embedding_features}" + + assert self.patch_pooler in [ + "cls", + "mean", + "prompt_mean", + ], f"Invalid method to extract prompt embedding features. Got {patch_pooler}" + + for n, p in self._backbone["transformer"].named_parameters(): + p.requires_grad = False + self._backbone["transformer"].eval() + for p in self._backbone["prompter"].parameters(): + p.requires_grad = True + + if self._is_text_transformer: + ## This is to find the Embedding layer. + for named_param, value in self._backbone["transformer"].named_parameters(): + if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size: + self.word_embeddings = self._backbone["transformer"].get_submodule( + named_param.replace(".weight", "") + ) + break + # The backbone's forward is monkey-patched to allow the parent class' forward to work + # without any manual management. + self._backbone.forward = self.forward_for_monkey_patching + + def forward_for_monkey_patching( + self, x: torch.Tensor, task_id: str = defaults.TASK_ID + ) -> torch.Tensor: + if not self._is_text_transformer: + # The vision transformer code is manual strapping in. + with torch.no_grad(): + prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False) + if self.prompt_embedding_features == "cls": + # retrieve cls token features. This is used in L2P paper. + prompt_pool_input = prompt_pool_input[:, 0, :] + elif self.prompt_embedding_features == "mean": + # compute mean patch features. + prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1) + # Compute the prompts to be stacked + prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) + # compute patch embeddings + patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")( + x + ) + # concatenate both. + input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1) + ## rest of processing. this code is part of the ViTModel class in HF Transformers. + encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")( + input_concat_prompt, return_dict=False + )[0] + encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")( + encoded_features + ) + + ## Save similarity + self.similarity_score = prompt_similarity + + if self.patch_pooler == "cls": + seq_cls_token = encoded_features[:, 0, :] + elif self.patch_pooler == "mean": + seq_cls_token = encoded_features[:, 1:, :].mean(1) + elif self.patch_pooler == "prompt_mean": + num_prompts = prompts.size(1) + seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1) + return seq_cls_token + + else: + ## The implicit assumption here is that x for text transformers is the input_ids. + # This simplified forward pass has 4 steps: + # 1. Get prompts + # 2. Get embeddings from inputs. + # 3. Concat prompt and inputs + # 4. Forward prop inputs_embeds to get the features. 
The forward of the RenateBM applies + # the classifier and gets logits. + with torch.no_grad(): + prompt_pool_input = self._backbone["transformer"].get_features(x) + prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) # 1 + self.similarity_score = prompt_similarity + inputs_embeds = self.word_embeddings(x["input_ids"]) # 2 + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # 3 + return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds}) # 4 diff --git a/src/renate/benchmark/models/transformer.py b/src/renate/benchmark/models/transformer.py index d1909233..4f4afc8b 100644 --- a/src/renate/benchmark/models/transformer.py +++ b/src/renate/benchmark/models/transformer.py @@ -11,9 +11,9 @@ class FeatureExtractorTextTransformer(PreTrainedModel): """This is a facade class to extract the correct output from the transformer model.""" - def __init__(self, pretrained_model_name: str): + def __init__(self, pretrained_model_name_or_path: str): model = AutoModelForTextEncoding.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name + pretrained_model_name_or_path=pretrained_model_name_or_path ) super().__init__(model.config) self._model = model @@ -30,7 +30,7 @@ class HuggingFaceSequenceClassificationTransformer(RenateBenchmarkingModule): """RenateBenchmarkingModule which wraps around Hugging Face transformers. Args: - pretrained_model_name: Hugging Face model id. + pretrained_model_name_or_path: Hugging Face model id. num_outputs: Number of outputs. prediction_strategy: Continual learning strategies may alter the prediction at train or test time. @@ -40,13 +40,15 @@ class HuggingFaceSequenceClassificationTransformer(RenateBenchmarkingModule): def __init__( self, - pretrained_model_name: str, + pretrained_model_name_or_path: str, num_outputs: int = 10, prediction_strategy: Optional[PredictionStrategy] = None, add_icarl_class_means: bool = True, ): - model = FeatureExtractorTextTransformer(pretrained_model_name=pretrained_model_name) - constructor_args = dict(pretrained_model_name=pretrained_model_name) + model = FeatureExtractorTextTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path + ) + constructor_args = dict(pretrained_model_name_or_path=pretrained_model_name_or_path) super().__init__( embedding_size=model.config.hidden_size, num_outputs=num_outputs, @@ -56,3 +58,6 @@ def __init__( ) self._backbone = model + + def get_features(self, *args, **kwargs): + return self._backbone(*args, **kwargs) diff --git a/src/renate/benchmark/models/vision_transformer.py b/src/renate/benchmark/models/vision_transformer.py index 4d0e8dce..79fc633d 100644 --- a/src/renate/benchmark/models/vision_transformer.py +++ b/src/renate/benchmark/models/vision_transformer.py @@ -6,39 +6,41 @@ from transformers import ViTConfig, ViTModel from transformers.modeling_outputs import BaseModelOutputWithPooling +from renate import defaults from renate.benchmark.models.base import RenateBenchmarkingModule from renate.models.prediction_strategies import PredictionStrategy class FeatureExtractorViTModel(ViTModel): - """This class directly outputs [CLS] features directly""" + """This class directly outputs [CLS] features if cls_feat is True else returns per patch + embeddings + """ def forward( self, pixel_values: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, 
interpolate_pos_encoding: Optional[bool] = None, return_dict: Optional[bool] = None, + cls_feat: bool = True, + task_id: str = defaults.TASK_ID, ) -> Union[Tuple, BaseModelOutputWithPooling]: """Output has patch embeddings and the pooled output. We extract pooled CLS out by taking the second element. """ out_to_filter = super().forward( pixel_values, - bool_masked_pos, head_mask, output_attentions, output_hidden_states, interpolate_pos_encoding, return_dict, ) - if isinstance(out_to_filter, BaseModelOutputWithPooling): - return out_to_filter.pooler_output - return out_to_filter[1] + return out_to_filter.pooler_output if cls_feat else out_to_filter.last_hidden_state + return out_to_filter[0][:, 0] if cls_feat else out_to_filter[0] class VisionTransformer(RenateBenchmarkingModule): @@ -84,8 +86,12 @@ def __init__( ) -> None: if pretrained_model_name_or_path: model = FeatureExtractorViTModel.from_pretrained( - pretrained_model_name_or_path=pretrained_model_name_or_path, return_dict=False + pretrained_model_name_or_path=pretrained_model_name_or_path, + return_dict=False, + add_pooling_layer=False, ) + + constructor_args = dict() else: model_config = ViTConfig( hidden_size=hidden_dim, @@ -103,12 +109,8 @@ def __init__( return_dict=False, ) - model = FeatureExtractorViTModel(config=model_config) - - super().__init__( - embedding_size=hidden_dim, - num_outputs=num_outputs, - constructor_arguments={ + model = FeatureExtractorViTModel(config=model_config, add_pooling_layer=False) + constructor_args = { "image_size": image_size, "patch_size": patch_size, "num_layers": num_layers, @@ -117,12 +119,22 @@ def __init__( "mlp_dim": mlp_dim, "dropout": dropout, "attention_dropout": attention_dropout, - }, + } + + super().__init__( + embedding_size=model.config.hidden_size, + num_outputs=num_outputs, + constructor_arguments=constructor_args, prediction_strategy=prediction_strategy, add_icarl_class_means=add_icarl_class_means, ) self._backbone = model + def get_features(self, *args, **kwargs): + # This is need as a shortcut to not call the base class's forward and directly call the + # backbone's forward. Used only in L2P. 
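# A minimal sketch of how L2P consumes this hook (the names `vit` and `pixel_values` are
# illustrative only; shapes assume a 224x224 input with 16x16 patches and hidden size 768):
#     patch_feats = vit.get_features(pixel_values, cls_feat=False)  # (B, 197, 768)
#     query = patch_feats[:, 0, :]  # [CLS] feature used as the query into the PromptPool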
+ return self._backbone(*args, **kwargs) + class VisionTransformerCIFAR(VisionTransformer): def __init__(self, **kwargs: Any) -> None: diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index fe4f33cd..0d94173e 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -22,6 +22,10 @@ from renate.updaters.experimental.fine_tuning import FineTuningModelUpdater from renate.updaters.experimental.gdumb import GDumbModelUpdater from renate.updaters.experimental.joint import JointModelUpdater +from renate.updaters.experimental.l2p import ( + LearningToPromptModelUpdater, + LearningToPromptReplayModelUpdater, +) from renate.updaters.experimental.offline_er import OfflineExperienceReplayModelUpdater from renate.updaters.experimental.repeated_distill import RepeatedDistillationModelUpdater from renate.updaters.model_updater import ModelUpdater @@ -60,6 +64,12 @@ def get_updater_and_learner_kwargs( if args.updater == "ER": learner_args = base_er_args + ["alpha"] updater_class = ExperienceReplayModelUpdater + elif args.updater == "LearningToPrompt": + learner_args = learner_args + ["prompt_sim_loss_weight"] + updater_class = LearningToPromptModelUpdater + elif args.updater == "LearningToPromptReplay": + learner_args = learner_args + ["prompt_sim_loss_weight", "memory_size", "memory_batch_size"] + updater_class = LearningToPromptReplayModelUpdater elif args.updater == "DER": learner_args = base_er_args + ["alpha", "beta"] updater_class = DarkExperienceReplayModelUpdater @@ -312,7 +322,7 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]: "true_type": bool, }, "gradient_clip_val": { - "type": lambda x: None if x == "None" else x, + "type": lambda x: None if x == "None" else float(x), "default": defaults.GRADIENT_CLIP_VAL, "help": "The value at which to clip gradients. None disables clipping.", "argument_group": OPTIONAL_ARGS_GROUP, @@ -460,6 +470,24 @@ def _add_base_experience_replay_arguments(arguments: Dict[str, Dict[str, Any]]) _add_replay_learner_arguments(arguments) +def _add_l2p_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: + arguments.update( + { + "prompt_sim_loss_weight": { + "type": float, + "default": defaults.PROMPT_SIM_LOSS_WEIGHT, + "help": "Prompt key similarity regularization weight. 
" + f"Default: {defaults.PROMPT_SIM_LOSS_WEIGHT}", + } + } + ) + + +def _add_l2preplay_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: + _add_l2p_arguments(arguments) + _add_offline_er_arguments(arguments) + + def _add_gdumb_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: """A helper function that adds GDumb arguments.""" _add_replay_learner_arguments(arguments) @@ -943,6 +971,8 @@ def get_scheduler_kwargs( parse_by_updater = { "ER": _add_experience_replay_arguments, + "LearningToPrompt": _add_l2p_arguments, + "LearningToPromptReplay": _add_l2preplay_arguments, "DER": _add_dark_experience_replay_arguments, "POD-ER": _add_pod_experience_replay_arguments, "CLS-ER": _add_cls_experience_replay_arguments, diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 12ed2983..e0854022 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -109,6 +109,9 @@ # Benchmark datasets/models TOKENIZER_KWARGS = {"padding": "max_length", "max_length": 128, "truncation": True} +# L2p +PROMPT_SIM_LOSS_WEIGHT = 0.5 + def scheduler(config_space: Dict[str, Any], mode: str, metric: str): return FIFOScheduler( diff --git a/src/renate/updaters/experimental/l2p.py b/src/renate/updaters/experimental/l2p.py new file mode 100644 index 00000000..92044813 --- /dev/null +++ b/src/renate/updaters/experimental/l2p.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import logging +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torchmetrics +from pytorch_lightning.loggers.logger import Logger +from pytorch_lightning.utilities.types import STEP_OUTPUT +from torch.nn import Parameter +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler + +from renate import defaults +from renate.benchmark.models.l2p import LearningToPromptTransformer +from renate.models import RenateModule +from renate.types import NestedTensors +from renate.updaters.experimental.offline_er import OfflineExperienceReplayLearner +from renate.updaters.learner import Learner +from renate.updaters.model_updater import SingleTrainingLoopUpdater +from renate.utils.misc import maybe_populate_mask_and_ignore_logits + +logger = logging.getLogger(__name__) + + +class LearningToPromptLearner(Learner): + """Learner for learning to prompt + + This is identical to the base learner with an addition of loss term. + TODO: Make this loss a component. 
+ + Args: + prompt_sim_loss_weight: Loss weight for the prompt key - image representation similarity + """ + + def __init__( + self, + prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT, + **kwargs, + ) -> None: + assert isinstance( + kwargs["model"], LearningToPromptTransformer + ), f"{self.__class__.__name__} can only train a LearningToPromptTransformer model" + f"but got {type(kwargs['model'])}" + super().__init__( + **kwargs, + ) + self.prompt_sim_loss_weight = prompt_sim_loss_weight + self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()}) + + def training_step( + self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int + ) -> STEP_OUTPUT: + loss_dict = super().training_step(batch, batch_idx=batch_idx) + key_similarity = -1 * self.prompt_sim_loss_weight * self._model.similarity_score + loss_dict["loss"] += key_similarity + self._loss_collections["train_losses"]["key_sim"](-key_similarity) + return loss_dict + + +class LearningToPromptReplayLearner(OfflineExperienceReplayLearner): + """L2P with an off-line ER learner. + + The model will be trained on weighted mixture of losses computed on the new data and a replay + buffer. In contrast to the online version, the buffer will only be updated after training has + terminated. + + Args: + prompt_sim_loss_weight: Loss weight for the prompt key - image representation similarity + """ + + def __init__( + self, + prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT, + **kwargs, + ) -> None: + assert isinstance( + kwargs["model"], LearningToPromptTransformer + ), f"{self.__class__.__name__} can only train a LearningToPromptTransformer model" + f"but got {type(kwargs['model'])}" + + super().__init__(**kwargs) + self.prompt_sim_loss_weight = prompt_sim_loss_weight + self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()}) + + def training_step( + self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int + ) -> STEP_OUTPUT: + """PyTorch Lightning function to return the training loss.""" + # The reason for rewriting is to ensure two independent forward props of inputs and memory + # samples. LearningToPromptTransformer uses per_batch_prompt which uses a single prompt + # repeated across the batch. Hence, the separate processing of memory and input samples. 
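# Summary of the objective assembled below (a description of the math, not extra computation):
# with alpha the weight on current-task data, w = prompt_sim_loss_weight and sim the prompt-key
# similarity score exposed by the model after its forward pass,
#     loss = alpha * loss_fn(current) + (1 - alpha) * loss_fn(memory) - w * sim
# and, when no memory batch is available yet, simply loss = loss_fn(current) - w * sim.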
+ if self._loss_weight_new_data is None: + alpha = self._num_points_current_task / ( + self._num_points_current_task + self._num_points_previous_tasks + ) + else: + alpha = self._loss_weight_new_data + alpha = torch.tensor(alpha, device=batch["current_task"][0].device) + inputs, targets = batch["current_task"] + outputs = self(inputs) + + outputs, self._class_mask = maybe_populate_mask_and_ignore_logits( + self._mask_unused_classes, + self._class_mask, + self._classes_in_current_task, + outputs, + ) + + if "memory" in batch: + (inputs_mem, targets_mem), _ = batch["memory"] + outputs_mem = self(inputs_mem) + + outputs_mem, self._class_mask = maybe_populate_mask_and_ignore_logits( + self._mask_unused_classes, + self._class_mask, + self._classes_in_current_task, + outputs_mem, + ) + + loss_current = self._loss_fn(outputs, targets).mean() + if "memory" in batch: + loss_memory = self._loss_fn(outputs_mem, targets_mem).mean() + self._loss_collections["train_losses"]["base_loss"](loss_current) + self._loss_collections["train_losses"]["memory_loss"](loss_memory) + loss = alpha * loss_current + (1.0 - alpha) * loss_memory + else: + loss = loss_current.mean() + self._loss_collections["train_losses"]["base_loss"](loss) + self._update_metrics(outputs, targets, "train") + + key_similarity = -1 * self.prompt_sim_loss_weight * self._model.similarity_score + loss += key_similarity + self._loss_collections["train_losses"]["key_sim"](-key_similarity) + return {"loss": loss} + + +class LearningToPromptModelUpdater(SingleTrainingLoopUpdater): + def __init__( + self, + model: RenateModule, + loss_fn: torch.nn.Module, + optimizer: Callable[[List[nn.Parameter]], Optimizer], + batch_size: int = defaults.BATCH_SIZE, + seed: int = defaults.SEED, + learner_kwargs: Optional[Dict[str, Any]] = None, + input_state_folder: Optional[str] = None, + output_state_folder: Optional[str] = None, + max_epochs: int = defaults.MAX_EPOCHS, + learning_rate_scheduler: Optional[Optional[Callable[[Optimizer], _LRScheduler]]] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 + prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT, + train_transform: Optional[Callable] = None, + train_target_transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + test_target_transform: Optional[Callable] = None, + buffer_transform: Optional[Callable] = None, + buffer_target_transform: Optional[Callable] = None, + metric: Optional[str] = None, + mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min", + logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, + early_stopping_enabled: bool = False, + logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), + accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, + devices: Optional[int] = None, + strategy: Optional[str] = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, + deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, + ): + learner_kwargs = { + "batch_size": batch_size, + "seed": seed, + "loss_fn": loss_fn, + "prompt_sim_loss_weight": prompt_sim_loss_weight, + } + super().__init__( + model=model, + loss_fn=loss_fn, + optimizer=optimizer, + learner_class=LearningToPromptLearner, + learner_kwargs=learner_kwargs, + 
input_state_folder=input_state_folder, + output_state_folder=output_state_folder, + max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, + train_transform=train_transform, + train_target_transform=train_target_transform, + test_transform=test_transform, + test_target_transform=test_target_transform, + buffer_transform=buffer_transform, + buffer_target_transform=buffer_target_transform, + metric=metric, + mode=mode, + logged_metrics=logged_metrics, + early_stopping_enabled=early_stopping_enabled, + logger=logger, + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=precision, + deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, + ) + + +class LearningToPromptReplayModelUpdater(SingleTrainingLoopUpdater): + def __init__( + self, + model: RenateModule, + loss_fn: torch.nn.Module, + optimizer: Callable[[List[Parameter]], Optimizer], + memory_size: int, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, + loss_weight_new_data: Optional[float] = None, + learning_rate_scheduler: Optional[partial] = None, + learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 + prompt_sim_loss_weight: float = defaults.PROMPT_SIM_LOSS_WEIGHT, + batch_size: int = defaults.BATCH_SIZE, + input_state_folder: Optional[str] = None, + output_state_folder: Optional[str] = None, + max_epochs: int = defaults.MAX_EPOCHS, + train_transform: Optional[Callable] = None, + train_target_transform: Optional[Callable] = None, + test_transform: Optional[Callable] = None, + test_target_transform: Optional[Callable] = None, + buffer_transform: Optional[Callable] = None, + buffer_target_transform: Optional[Callable] = None, + metric: Optional[str] = None, + mode: defaults.SUPPORTED_TUNING_MODE_TYPE = "min", + logged_metrics: Optional[Dict[str, torchmetrics.Metric]] = None, + early_stopping_enabled: bool = False, + logger: Logger = defaults.LOGGER(**defaults.LOGGER_KWARGS), + accelerator: defaults.SUPPORTED_ACCELERATORS_TYPE = defaults.ACCELERATOR, + devices: Optional[int] = None, + strategy: str = defaults.DISTRIBUTED_STRATEGY, + precision: str = defaults.PRECISION, + seed: int = defaults.SEED, + deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER, + gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL, + gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM, + mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES, + ): + learner_kwargs = { + "memory_size": memory_size, + "batch_memory_frac": batch_memory_frac, + "loss_weight_new_data": loss_weight_new_data, + "batch_size": batch_size, + "seed": seed, + "prompt_sim_loss_weight": prompt_sim_loss_weight, + } + super().__init__( + model, + loss_fn=loss_fn, + optimizer=optimizer, + learner_class=LearningToPromptReplayLearner, + learner_kwargs=learner_kwargs, + input_state_folder=input_state_folder, + output_state_folder=output_state_folder, + max_epochs=max_epochs, + learning_rate_scheduler=learning_rate_scheduler, + learning_rate_scheduler_interval=learning_rate_scheduler_interval, + train_transform=train_transform, + train_target_transform=train_target_transform, + test_transform=test_transform, + test_target_transform=test_target_transform, + buffer_transform=buffer_transform, + buffer_target_transform=buffer_target_transform, + 
metric=metric, + mode=mode, + logged_metrics=logged_metrics, + early_stopping_enabled=early_stopping_enabled, + logger=logger, + accelerator=accelerator, + devices=devices, + strategy=strategy, + precision=precision, + deterministic_trainer=deterministic_trainer, + gradient_clip_algorithm=gradient_clip_algorithm, + gradient_clip_val=gradient_clip_val, + mask_unused_classes=mask_unused_classes, + ) diff --git a/test/conftest.py b/test/conftest.py index 9d1fd5ce..974cce1c 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -7,7 +7,6 @@ import pytest import torch from pytorch_lightning.loggers import TensorBoardLogger - from renate import defaults from renate.benchmark.models import ( MultiLayerPerceptron, @@ -31,12 +30,11 @@ AvalancheLwFLearner, AvalancheReplayLearner, ) -from renate.updaters.avalanche.model_updater import ( - AvalancheModelUpdater, -) +from renate.updaters.avalanche.model_updater import AvalancheModelUpdater from renate.updaters.experimental.er import ExperienceReplayLearner from renate.updaters.experimental.gdumb import GDumbLearner from renate.updaters.experimental.joint import JointLearner +from renate.updaters.experimental.l2p import LearningToPromptLearner, LearningToPromptReplayLearner from renate.updaters.experimental.offline_er import OfflineExperienceReplayLearner from renate.updaters.experimental.repeated_distill import RepeatedDistillationLearner from renate.updaters.learner import Learner, ReplayLearner @@ -76,7 +74,20 @@ def pytest_collection_modifyitems(config, items): "seed": 1, }, Learner: {"batch_size": 10, "seed": 42}, - GDumbLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, + LearningToPromptLearner: {"batch_size": 10, "seed": 42, "prompt_sim_loss_weight": 1}, + LearningToPromptReplayLearner: { + "batch_size": 10, + "seed": 42, + "prompt_sim_loss_weight": 1, + "loss_weight_new_data": 0.5, + "memory_size": 30, + "batch_memory_frac": 0.3, + }, + GDumbLearner: { + "batch_size": 10, + "seed": 42, + "memory_size": 30, + }, JointLearner: {"batch_size": 10, "seed": 3}, RepeatedDistillationLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, OfflineExperienceReplayLearner: { @@ -126,6 +137,10 @@ def pytest_collection_modifyitems(config, items): JointLearner, OfflineExperienceReplayLearner, ] +L2P_LEARNERS = [ + LearningToPromptLearner, + LearningToPromptReplayLearner, +] SAMPLE_CLASSIFICATION_RESULTS = { "accuracy": [ diff --git a/test/renate/benchmark/models/test_l2p.py b/test/renate/benchmark/models/test_l2p.py new file mode 100644 index 00000000..a52b0a3b --- /dev/null +++ b/test/renate/benchmark/models/test_l2p.py @@ -0,0 +1,64 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool + + +def test_prompt_pool(): + D_emb = 2 + Lp = 3 + feat_dim = 6 + B = 4 + N = 2 + + pool = PromptPool( + embedding_dim=feat_dim, prompt_key_dim=D_emb, prompt_size=Lp, pool_selection_size=N + ) + input = torch.rand(B, D_emb) + out = pool(input)[0] + + assert out.shape == (B, Lp * N, feat_dim) + + +def test_prompted_vision_transformer(): + combined = LearningToPromptTransformer() + inp = torch.rand(1, 3, 224, 224) + assert combined(inp).shape == torch.Size((1, 10)) + + +def test_prompted_text_transformer(): + model = LearningToPromptTransformer(pretrained_model_name_or_path="bert-base-uncased") + inp = {"input_ids": torch.randint(0, 3000, (10, 128))} + assert model(inp).shape == torch.Size((10, 10)) + + +@pytest.mark.parametrize( + "cls,arg,argval,error", + [ + [PromptPool, "similarity_fn", "not_cosine", ValueError], + [LearningToPromptTransformer, "prompt_embedding_features", "not_cls", AssertionError], + [LearningToPromptTransformer, "patch_pooler", "not_cls", AssertionError], + ], +) +def test_pool_vision_transformer_raises_errors(cls, arg, argval, error): + with pytest.raises(error): + cls(**{arg: argval}) + + +@pytest.mark.parametrize( + "backbone,num_trainable_params", + [ + ["google/vit-base-patch16-224", 4], + ["google/vit-base-patch32-224-in21k", 4], + ["google/vit-large-patch32-224-in21k", 4], + ["bert-base-uncased", 4], + ["distilbert-base-uncased", 4], + ], +) +def test_prompt_vision_transformer_trainable_parameters(backbone, num_trainable_params): + # The result is always 4: prompt pool, prompt pool key, classifier wt, classifier bias. + model = LearningToPromptTransformer(pretrained_model_name_or_path=backbone) + n = sum(1 for x in model.parameters() if x.requires_grad) + assert n == num_trainable_params diff --git a/test/renate/benchmark/models/test_text_transformer.py b/test/renate/benchmark/models/test_text_transformer.py index d9426fe4..56c21699 100644 --- a/test/renate/benchmark/models/test_text_transformer.py +++ b/test/renate/benchmark/models/test_text_transformer.py @@ -8,7 +8,9 @@ @pytest.mark.parametrize("model_name", ["distilbert-base-uncased", "bert-base-uncased"]) def test_init(model_name): - HuggingFaceSequenceClassificationTransformer(pretrained_model_name=model_name, num_outputs=10) + HuggingFaceSequenceClassificationTransformer( + pretrained_model_name_or_path=model_name, num_outputs=10 + ) @pytest.mark.parametrize( @@ -19,7 +21,9 @@ def test_init(model_name): ], ) def test_text_transformer_fwd(model_name, input_dim): - transformer = HuggingFaceSequenceClassificationTransformer(pretrained_model_name=model_name) + transformer = HuggingFaceSequenceClassificationTransformer( + pretrained_model_name_or_path=model_name + ) x = {"input_ids": torch.randint(0, 30000, (5, *input_dim))} y_hat = transformer(x) diff --git a/test/renate/benchmark/models/test_vision_transformer.py b/test/renate/benchmark/models/test_vision_transformer.py index 8ebe19c6..5efba785 100644 --- a/test/renate/benchmark/models/test_vision_transformer.py +++ b/test/renate/benchmark/models/test_vision_transformer.py @@ -44,12 +44,12 @@ def test_renate_vision_transformer_fwd(sub_class, input_dim): @pytest.mark.parametrize( "sub_class, expected_num_params", [ - ["visiontransformercifar", 56], - ["visiontransformerb16", 200], - ["visiontransformerb32", 200], - ["visiontransformerl16", 392], - ["visiontransformerl32", 392], - ["visiontransformerh14", 
520], + ["visiontransformercifar", 54], + ["visiontransformerb16", 198], + ["visiontransformerb32", 198], + ["visiontransformerl16", 390], + ["visiontransformerl32", 390], + ["visiontransformerh14", 518], ], ) def test_renate_vision_transformer_get_params(sub_class, expected_num_params): diff --git a/test/renate/benchmark/test_experimentation_config.py b/test/renate/benchmark/test_experimentation_config.py index 062420dc..1da7b27c 100644 --- a/test/renate/benchmark/test_experimentation_config.py +++ b/test/renate/benchmark/test_experimentation_config.py @@ -46,7 +46,7 @@ def test_model_fn(model_name, expected_model_class): num_outputs=2, num_hidden_layers=1 if model_name == "MultiLayerPerceptron" else None, hidden_size=1 if model_name == "MultiLayerPerceptron" else None, - pretrained_model_name="distilbert-base-uncased" + pretrained_model_name_or_path="distilbert-base-uncased" if model_name == "HuggingFaceTransformer" else None, ) @@ -82,7 +82,7 @@ def test_model_fn_fails_for_unknown_model(): @pytest.mark.parametrize( - "dataset_name,data_module_class,pretrained_model_name,input_column,target_column", + "dataset_name,data_module_class,pretrained_model_name_or_path,input_column,target_column", ( ("CIFAR10", TorchVisionDataModule, None, None, None), ("CLEAR10", CLEARDataModule, None, None, None), @@ -97,7 +97,12 @@ def test_model_fn_fails_for_unknown_model(): ), ) def test_get_data_module( - tmpdir, dataset_name, data_module_class, pretrained_model_name, input_column, target_column + tmpdir, + dataset_name, + data_module_class, + pretrained_model_name_or_path, + input_column, + target_column, ): data_module = get_data_module( data_path=tmpdir, @@ -106,7 +111,7 @@ def test_get_data_module( seed=0, src_bucket=None, src_object_name=None, - pretrained_model_name=pretrained_model_name, + pretrained_model_name_or_path=pretrained_model_name_or_path, input_column=input_column, target_column=target_column, ) @@ -123,7 +128,7 @@ def test_get_data_module_fails_for_unknown_dataset(tmpdir): seed=0, src_bucket=None, src_object_name=None, - pretrained_model_name=None, + pretrained_model_name_or_path=None, input_column=None, target_column=None, ) @@ -137,7 +142,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): seed=0, src_bucket=None, src_object_name=None, - pretrained_model_name=None, + pretrained_model_name_or_path=None, input_column=None, target_column=None, ) @@ -155,7 +160,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): "ClassIncrementalScenario", "hfd-trec", { - "pretrained_model_name": "distilbert-base-uncased", + "pretrained_model_name_or_path": "distilbert-base-uncased", "input_column": "text", "target_column": "coarse_label", "groupings": ((0, 1), (2, 3), (4, 5)), @@ -206,7 +211,7 @@ def test_get_scenario_fails_for_unknown_scenario(tmpdir): ( "DataIncrementalScenario", "arxiv", - {"num_tasks": 3, "pretrained_model_name": "distilbert-base-uncased"}, + {"num_tasks": 3, "pretrained_model_name_or_path": "distilbert-base-uncased"}, DataIncrementalScenario, 3, ), @@ -274,7 +279,7 @@ def test_data_module_fn( elif expected_scenario_class == HueShiftScenario: assert scenario._randomness == scenario_kwargs["randomness"] elif expected_scenario_class == DataIncrementalScenario: - if "pretrained_model_name" in scenario_kwargs: + if "pretrained_model_name_or_path" in scenario_kwargs: assert scenario._data_module._tokenizer is not None elif dataset_name not in ["CLEAR10", "CLEAR100", "DomainNet"]: assert scenario._data_module._tokenizer is None @@ -351,7 +356,7 @@ def 
test_prediction_strategy_is_correctly_set(model_name, updater): if model_name == "MultiLayerPerceptron": model_kwargs.update({"num_inputs": 10, "hidden_size": 10, "num_hidden_layers": 2}) elif model_name == "HuggingFaceTransformer": - model_kwargs["pretrained_model_name"] = "distilbert-base-uncased" + model_kwargs["pretrained_model_name_or_path"] = "distilbert-base-uncased" if model_name == "HuggingFaceTransformer" and updater == "Avalanche-iCaRL": with pytest.raises(ValueError, match="Transformers do not support iCaRL."): model_fn(**model_kwargs) diff --git a/test/renate/updaters/test_learner.py b/test/renate/updaters/test_learner.py index 0273702e..d53a5a7a 100644 --- a/test/renate/updaters/test_learner.py +++ b/test/renate/updaters/test_learner.py @@ -4,7 +4,8 @@ import pytest -from conftest import LEARNERS, LEARNER_KWARGS +from conftest import L2P_LEARNERS, LEARNERS, LEARNER_KWARGS +from renate.benchmark.models.l2p import LearningToPromptTransformer from renate.models import RenateModule from renate.updaters.learner import Learner @@ -39,7 +40,23 @@ def check_learner_variables(learner: Learner, expected_variable_values: Dict[str @pytest.mark.parametrize("learner_class", LEARNERS) def test_save_and_load_learner(learner_class): - model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) - checkpoint_dict = {} - learner.on_save_checkpoint(checkpoint=checkpoint_dict) - check_learner_variables(learner, checkpoint_dict) + if learner_class not in L2P_LEARNERS: + model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) + checkpoint_dict = {} + learner.on_save_checkpoint(checkpoint=checkpoint_dict) + check_learner_variables(learner, checkpoint_dict) + else: + with pytest.raises(AssertionError): + model, learner, learner_kwargs = get_model_and_learner_and_learner_kwargs(learner_class) + + learner_kwargs = LEARNER_KWARGS[learner_class] + model = LearningToPromptTransformer() + learner = learner_class( + model=model, + loss_fn=pytest.helpers.get_loss_fn(), + optimizer=pytest.helpers.get_partial_optimizer(), + **learner_kwargs, + ) + checkpoint_dict = {} + learner.on_save_checkpoint(checkpoint=checkpoint_dict) + check_learner_variables(learner, checkpoint_dict) From 538894ea73d424a7b47ce08643a23e9d1c2c0e03 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Wed, 23 Aug 2023 16:06:59 +0200 Subject: [PATCH 77/89] Fixing the issue with Domainnet redownloading (#389) --- src/renate/benchmark/datasets/vision_datasets.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index abd33af6..9dc7045a 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -312,11 +312,11 @@ class DomainNetDataModule(DataIncrementalDataModule): "infograph_train.txt": "379b50054f4ac2018dca4f89421b92d9", "infograph_test.txt": "779626b50869edffe8ea6941c3755c71", "painting.zip": "1ae32cdb4f98fe7ab5eb0a351768abfd", - "painting_train.txt": "b732ced3939ac8efdd8c0a889dca56cc", - "painting_test.txt": "c1a828fdfe216fb109f1c0083a252c6f", + "painting_train.txt": "7db0e7ca73ad9982f6e1f7f3ef372c0a", + "painting_test.txt": "232b35dc53f26d414686ae579e38d9b5", "quickdraw.zip": "bdc1b6f09f277da1a263389efe0c7a66", - "quickdraw_train.txt": "b4349693a7f9c05c53955725c47ed6cb", - "quickdraw_test.txt": "f5ddbcfd657a3acf9d0f7da10db22565", + "quickdraw_train.txt": "b732ced3939ac8efdd8c0a889dca56cc", + 
"quickdraw_test.txt": "c1a828fdfe216fb109f1c0083a252c6f", "real.zip": "dcc47055e8935767784b7162e7c7cca6", "real_train.txt": "8ebf02c2075fadd564705f0dc7cd6291", "real_test.txt": "6098816791c3ebed543c71ffa11b9054", @@ -354,6 +354,7 @@ def __init__( seed=seed, ) assert self.data_id in self.domains, f"Unknown domain {self.data_id}." + self._dataset_name = domain.lower() def prepare_data(self) -> None: """Download DomainNet dataset for given domain.""" From e2035fe4ba9ff57aaac283642a98beca7da8797c Mon Sep 17 00:00:00 2001 From: Giovanni <52964960+610v4nn1@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:17:55 +0200 Subject: [PATCH 78/89] Improve title for the NLP example (#416) --- doc/examples/nlp_finetuning.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/examples/nlp_finetuning.rst b/doc/examples/nlp_finetuning.rst index 311d409e..1eb39404 100644 --- a/doc/examples/nlp_finetuning.rst +++ b/doc/examples/nlp_finetuning.rst @@ -1,5 +1,5 @@ -Working with NLP -**************** +Working with NLP and Large Language Models +****************************************** This example demonstrates how to use Renate to train NLP models. We will train a sequence classifier to distinguish between positive and negative movie reviews. Using Renate, we will sequentially From 153027fdeae2799955e8b291029d4d6a0e70f621 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:25:50 +0200 Subject: [PATCH 79/89] Bump pytest from 7.4.0 to 7.4.2 (#413) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 249b7c43..91846728 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dev = [ "wild-time-data==0.1.1", "torch>=1.10.0, <1.12.2", # PyTest Dependencies - "pytest==7.4.0", + "pytest==7.4.2", "pytest-cov==4.1.0", "pytest-helpers-namespace==2021.12.29", ] From 9889433672dcdbfdb51c535cb6c09d396b890298 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 13 Sep 2023 14:27:24 +0200 Subject: [PATCH 80/89] Bump black from 23.3.0 to 23.9.1 (#407) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 91846728..5cf9e031 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ benchmark = [ "wild-time-data==0.1.1", ] dev = [ - "black==23.3.0", + "black==23.9.1", "avalanche_lib==0.3.1", "wild-time-data==0.1.1", "torch>=1.10.0, <1.12.2", From 6d3b42b32ce7d53dbb75bd9a08345ca67b610dcc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:25:11 +0200 Subject: [PATCH 81/89] Update scipy requirement from <1.11.2,>=1.9.0 to >=1.9.0,<1.11.3 (#382) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 517ccc5b..11922024 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,4 @@ torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 datasets>=2.9.0, <2.14.1 transformers>=4.31.0, <4.31.1 -scipy>=1.9.0, <1.11.2 +scipy>=1.9.0, <1.11.3 From 7ee47783d6949ec06743bbfb0d1992dd0521c5ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:41:26 +0200 Subject: [PATCH 82/89] Update numpy requirement from <1.25.2,>=1.17.2 to >=1.17.2,<1.25.3 (#375) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/requirements.txt b/requirements.txt index 11922024..ab6e6e2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.17.2, <1.25.2 +numpy>=1.17.2, <1.25.3 torch>=1.10.0, <1.13.2 pandas>=1.4.0, <2.0.4 boto3>=1.26.0, <1.26.139 From bb757b63ff253d5f690c241a1dd6644aadd04722 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 Sep 2023 16:42:15 +0200 Subject: [PATCH 83/89] Update tensorboardx requirement from <2.6.2,>=2.5.0 to >=2.5.0,<2.6.3 (#377) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ab6e6e2f..6effe82e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ syne-tune[aws,gpsearchers]==0.6.0 pytorch-lightning>=1.8.0, <1.9.5 Pillow>=9.0, <9.5.1 tabulate>=0.9.0, <0.9.1 -tensorboardX>=2.5.0, <2.6.2 +tensorboardX>=2.5.0, <2.6.3 torchmetrics>=0.11.0, <0.11.5 torchvision>=0.13.0, <0.15.2 deepspeed==0.9.1 From 6ada0b050ad6d94ae87d035f243bde194208d867 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:51:50 +0200 Subject: [PATCH 84/89] Update pandas requirement from <2.0.4,>=1.4.0 to >=1.4.0,<2.1.1 (#404) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6effe82e..e741917e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ numpy>=1.17.2, <1.25.3 torch>=1.10.0, <1.13.2 -pandas>=1.4.0, <2.0.4 +pandas>=1.4.0, <2.1.1 boto3>=1.26.0, <1.26.139 requests>=2.31.0, <2.31.1 sagemaker>=2.112.0, <2.158.1 From 5bc82b9405b8f02e91b1cb6870e1b7d99f0f3b78 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 15 Sep 2023 11:53:04 +0200 Subject: [PATCH 85/89] Bump actions/checkout from 3 to 4 (#398) --- .github/workflows/codeql.yml | 2 +- .github/workflows/integration_tests.yml | 2 +- .github/workflows/pypi_publish.yml | 2 +- .github/workflows/run_renate.yml | 2 +- .github/workflows/run_unit_tests.yml | 4 ++-- .github/workflows/sagemaker_tests.yml | 2 +- .github/workflows/test_docs.yml | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index ec9a97ed..d44c5670 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. 
- name: Initialize CodeQL diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 3948c3c2..39df2329 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -21,7 +21,7 @@ jobs: outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: 3.x diff --git a/.github/workflows/pypi_publish.yml b/.github/workflows/pypi_publish.yml index cad730c9..94f89888 100644 --- a/.github/workflows/pypi_publish.yml +++ b/.github/workflows/pypi_publish.yml @@ -10,7 +10,7 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 with: fetch-depth: 0 - name: Set up Python 3.10 diff --git a/.github/workflows/run_renate.yml b/.github/workflows/run_renate.yml index 59f1dd5d..7083b261 100644 --- a/.github/workflows/run_renate.yml +++ b/.github/workflows/run_renate.yml @@ -40,7 +40,7 @@ jobs: role-to-assume: ${{ secrets.PROD_AWS_END_TO_END_TEST_ROLE_ARN }} role-session-name: integtestsession aws-region: ${{ env.AWS_DEFAULT_REGION }} - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 uses: actions/setup-python@v4 with: diff --git a/.github/workflows/run_unit_tests.yml b/.github/workflows/run_unit_tests.yml index 21a52301..0f2a0522 100644 --- a/.github/workflows/run_unit_tests.yml +++ b/.github/workflows/run_unit_tests.yml @@ -15,7 +15,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Switch to Current Branch run: git checkout ${{ env.BRANCH }} @@ -60,7 +60,7 @@ jobs: runs-on: ubuntu-latest needs: test steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/download-artifact@v3 id: download diff --git a/.github/workflows/sagemaker_tests.yml b/.github/workflows/sagemaker_tests.yml index c576bcd6..400ff9f0 100644 --- a/.github/workflows/sagemaker_tests.yml +++ b/.github/workflows/sagemaker_tests.yml @@ -23,7 +23,7 @@ jobs: launch-sagemaker-jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: actions/setup-python@v4 with: python-version: 3.9 diff --git a/.github/workflows/test_docs.yml b/.github/workflows/test_docs.yml index fa4b7d1d..1f980e56 100644 --- a/.github/workflows/test_docs.yml +++ b/.github/workflows/test_docs.yml @@ -13,7 +13,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Switch to Current Branch run: git checkout ${{ env.BRANCH }} From 864c1a45cb7e66458b6b5abb5e1a678f72338227 Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Tue, 19 Sep 2023 16:10:41 +0200 Subject: [PATCH 86/89] Abstracting prompting transformer for use in L2P and S-Prompt (#420) --- src/renate/benchmark/models/l2p.py | 220 +++++++++++++++-------- src/renate/updaters/experimental/l2p.py | 2 +- test/renate/benchmark/models/test_l2p.py | 26 ++- 3 files changed, 169 insertions(+), 79 deletions(-) diff --git a/src/renate/benchmark/models/l2p.py b/src/renate/benchmark/models/l2p.py index cd0f8379..175eb3e8 100644 --- a/src/renate/benchmark/models/l2p.py +++ b/src/renate/benchmark/models/l2p.py @@ -108,6 +108,123 @@ def forward(self, x: torch.Tensor, manual_prompt_indices: Optional[torch.LongTen return selected_prompts, loss_value +class PromptedTransformer(nn.Module): + """This generic module is the basic prompted transformer. 
It takes in a model string and creates + the appropriate transformer (ViT or Text transformer). If no prompts are provided in the forward + call, image/text features are returned. If a prompt is provided, it is concatenated to the + embedding layer output and the resultant features are returned. + + Args: + pretrained_model_name_or_path: A string that denotes which pretrained model from the HF hub + to use. + num_outputs: Size of the output. + prediction_strategy: Continual learning strategies may alter the prediction at train or test + time. + add_icarl_class_means: If ``True``, additional parameters used only by the + ``ICaRLModelUpdater`` are added. Only required when using that updater. + """ + + def __init__( + self, + pretrained_model_name_or_path="google/vit-base-patch16-224", + image_size: int = 32, + patch_size: int = 4, + num_layers: int = 12, + num_heads: int = 12, + hidden_dim: int = 768, + mlp_dim: int = 3072, + dropout: float = 0.1, + attention_dropout: float = 0.1, + num_outputs: int = 10, + prediction_strategy: Optional[PredictionStrategy] = None, + add_icarl_class_means: bool = True, + ) -> None: + super().__init__() + if "vit" in pretrained_model_name_or_path: + self.transformer = VisionTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_outputs=num_outputs, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + ) + self.is_text_transformer = False + else: + self.transformer = HuggingFaceSequenceClassificationTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + num_outputs=num_outputs, + prediction_strategy=prediction_strategy, + add_icarl_class_means=add_icarl_class_means, + ) + for named_param, value in self.transformer.named_parameters(): + if value.shape[0] == self.transformer._backbone.config.vocab_size: + self.word_embeddings = self.transformer.get_submodule( + named_param.replace(".weight", "") + ) + break + + self.is_text_transformer = True + + self.transformer._tasks_params.clear() + self.transformer.eval() + for p in self.transformer.parameters(): + p.requires_grad_(False) + + def forward( + self, x: torch.Tensor, prompt: Optional[torch.Tensor] = None, cls_feat: bool = True + ) -> torch.Tensor: + """ + Args: + x: Input torch tensor. + prompt: Prompt tensor. Defaults to None. + cls_feat: Whether to extract [CLS] token or to return full feature tensor. + Ignored for text transformer. Defaults to True. + """ + if prompt is None: + return ( + self.transformer.get_features(x) + if self.is_text_transformer + else self.transformer.get_features(x, cls_feat=cls_feat) + ) + # text transformers dont support cls_feat. + elif self.is_text_transformer: + # The implicit assumption here is that x for text transformers is the input_ids. + # This simplified forward pass has 4 steps: + # 1. Get prompts + # 2. Get embeddings from inputs. + # 3. Concat prompt and inputs + # 4. Forward prop inputs_embeds to get the features. 
+ inputs_embeds = self.word_embeddings(x["input_ids"]) + if prompt.size(0) != inputs_embeds.size(0): + prompt = prompt.unsqueeze(0).expand( + inputs_embeds.size(0), -1, -1 + ) # Expand one prompt to batch size + inputs_embeds = torch.cat((prompt, inputs_embeds), dim=1) + return self.transformer.get_features({"inputs_embeds": inputs_embeds}) + else: + patch_embeddings = self.transformer.get_submodule("_backbone.embeddings")(x) + if prompt.size(0) != x.size(0): + prompt = prompt.unsqueeze(0).expand( + x.size(0), -1, -1 + ) # Expand one prompt to batch size# Expand one prompt to batch size + input_concat_prompt = torch.cat([patch_embeddings, prompt], dim=1) + + encoded_features = self.transformer.get_submodule("_backbone.encoder")( + input_concat_prompt, return_dict=False + )[0] + encoded_features = self.transformer.get_submodule("_backbone.layernorm")( + encoded_features + ) + return encoded_features[:, 0, :] if cls_feat else encoded_features + + class LearningToPromptTransformer(RenateBenchmarkingModule): """ Implements the vision transformer with prompt pool described in @@ -166,34 +283,22 @@ def __init__( prompt_embedding_features: str = "cls", patch_pooler: str = "prompt_mean", ) -> None: - if "vit" in pretrained_model_name_or_path: - transformer = VisionTransformer( - pretrained_model_name_or_path=pretrained_model_name_or_path, - image_size=image_size, - patch_size=patch_size, - num_layers=num_layers, - num_heads=num_heads, - hidden_dim=hidden_dim, - mlp_dim=mlp_dim, - dropout=dropout, - attention_dropout=attention_dropout, - prediction_strategy=prediction_strategy, - add_icarl_class_means=add_icarl_class_means, - num_outputs=num_outputs, - ) - self._is_text_transformer = False - else: - transformer = HuggingFaceSequenceClassificationTransformer( - pretrained_model_name_or_path=pretrained_model_name_or_path, - prediction_strategy=prediction_strategy, - add_icarl_class_means=add_icarl_class_means, - num_outputs=num_outputs, - ) - - self._is_text_transformer = True - transformer._tasks_params.clear() + transformer = PromptedTransformer( + pretrained_model_name_or_path=pretrained_model_name_or_path, + image_size=image_size, + patch_size=patch_size, + num_layers=num_layers, + num_heads=num_heads, + hidden_dim=hidden_dim, + mlp_dim=mlp_dim, + dropout=dropout, + attention_dropout=attention_dropout, + num_outputs=num_outputs, + add_icarl_class_means=add_icarl_class_means, + prediction_strategy=prediction_strategy, + ) prompter = PromptPool( - embedding_dim=transformer._embedding_size, + embedding_dim=transformer.transformer._embedding_size, pool_size=pool_size, pool_selection_size=pool_selection_size, prompt_size=prompt_size, @@ -204,10 +309,10 @@ def __init__( ) super().__init__( - embedding_size=transformer._embedding_size, + embedding_size=transformer.transformer._embedding_size, num_outputs=num_outputs, constructor_arguments=dict( - **transformer._constructor_arguments, + **transformer.transformer._constructor_arguments, pool_size=pool_size, pool_selection_size=pool_selection_size, prompt_size=prompt_size, @@ -221,6 +326,7 @@ def __init__( ) self._backbone = nn.ModuleDict({"transformer": transformer, "prompter": prompter}) + self._is_text_transformer = transformer.is_text_transformer self.prompt_embedding_features = prompt_embedding_features self.patch_pooler = patch_pooler self.similarity_score: Optional[torch.Tensor] = None @@ -236,20 +342,9 @@ def __init__( "prompt_mean", ], f"Invalid method to extract prompt embedding features. 
Got {patch_pooler}" - for n, p in self._backbone["transformer"].named_parameters(): - p.requires_grad = False - self._backbone["transformer"].eval() for p in self._backbone["prompter"].parameters(): p.requires_grad = True - if self._is_text_transformer: - ## This is to find the Embedding layer. - for named_param, value in self._backbone["transformer"].named_parameters(): - if value.shape[0] == self._backbone["transformer"]._backbone.config.vocab_size: - self.word_embeddings = self._backbone["transformer"].get_submodule( - named_param.replace(".weight", "") - ) - break # The backbone's forward is monkey-patched to allow the parent class' forward to work # without any manual management. self._backbone.forward = self.forward_for_monkey_patching @@ -257,10 +352,9 @@ def __init__( def forward_for_monkey_patching( self, x: torch.Tensor, task_id: str = defaults.TASK_ID ) -> torch.Tensor: + with torch.no_grad(): + prompt_pool_input = self._backbone["transformer"](x, cls_feat=False) if not self._is_text_transformer: - # The vision transformer code is manual strapping in. - with torch.no_grad(): - prompt_pool_input = self._backbone["transformer"].get_features(x, cls_feat=False) if self.prompt_embedding_features == "cls": # retrieve cls token features. This is used in L2P paper. prompt_pool_input = prompt_pool_input[:, 0, :] @@ -268,24 +362,12 @@ def forward_for_monkey_patching( # compute mean patch features. prompt_pool_input = prompt_pool_input[:, 1:, :].mean(1) # Compute the prompts to be stacked - prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) - # compute patch embeddings - patch_embeddings = self._backbone["transformer"].get_submodule("_backbone.embeddings")( - x - ) - # concatenate both. - input_concat_prompt = torch.cat([patch_embeddings, prompts], dim=1) - ## rest of processing. this code is part of the ViTModel class in HF Transformers. - encoded_features = self._backbone["transformer"].get_submodule("_backbone.encoder")( - input_concat_prompt, return_dict=False - )[0] - encoded_features = self._backbone["transformer"].get_submodule("_backbone.layernorm")( - encoded_features - ) - - ## Save similarity - self.similarity_score = prompt_similarity - + prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) + self.similarity_score = prompt_similarity + encoded_features = self._backbone["transformer"](x, prompts, cls_feat=False) + if self._is_text_transformer: + return encoded_features + else: if self.patch_pooler == "cls": seq_cls_token = encoded_features[:, 0, :] elif self.patch_pooler == "mean": @@ -294,19 +376,3 @@ def forward_for_monkey_patching( num_prompts = prompts.size(1) seq_cls_token = encoded_features[:, -num_prompts:, :].mean(1) return seq_cls_token - - else: - ## The implicit assumption here is that x for text transformers is the input_ids. - # This simplified forward pass has 4 steps: - # 1. Get prompts - # 2. Get embeddings from inputs. - # 3. Concat prompt and inputs - # 4. Forward prop inputs_embeds to get the features. The forward of the RenateBM applies - # the classifier and gets logits. 
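
A minimal usage sketch of the new PromptedTransformer abstraction, assuming only the interface shown in this diff (it mirrors the unit test added further below):

    import torch
    from renate.benchmark.models.l2p import PromptedTransformer

    # Frozen backbone; the pretrained ViT weights are fetched on first use.
    transformer = PromptedTransformer(
        pretrained_model_name_or_path="google/vit-base-patch16-224",
        num_outputs=10,
        prediction_strategy=None,
        add_icarl_class_means=False,
    )
    images = torch.rand(5, 3, 224, 224)
    # Without a prompt, the frozen backbone returns plain [CLS] features: shape (5, 768).
    features = transformer(images)
    # With a prompt, it is concatenated to the patch embeddings before the encoder:
    # shape (5, 197 + 10, 768) when cls_feat=False.
    prompt = torch.rand(5, 10, 768)
    prompted = transformer(images, prompt, cls_feat=False)

With this abstraction, LearningToPromptTransformer only owns the PromptPool and routes the selected prompts through the frozen module, instead of re-implementing the prompted forward pass per backbone type.
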
- with torch.no_grad(): - prompt_pool_input = self._backbone["transformer"].get_features(x) - prompts, prompt_similarity = self._backbone["prompter"](prompt_pool_input) # 1 - self.similarity_score = prompt_similarity - inputs_embeds = self.word_embeddings(x["input_ids"]) # 2 - inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) # 3 - return self._backbone["transformer"].get_features({"inputs_embeds": inputs_embeds}) # 4 diff --git a/src/renate/updaters/experimental/l2p.py b/src/renate/updaters/experimental/l2p.py index 92044813..ceed6ecc 100644 --- a/src/renate/updaters/experimental/l2p.py +++ b/src/renate/updaters/experimental/l2p.py @@ -48,7 +48,7 @@ def __init__( **kwargs, ) self.prompt_sim_loss_weight = prompt_sim_loss_weight - self._loss_collections["train_losses"].update({"key_sim_loss": torchmetrics.MeanMetric()}) + self._loss_collections["train_losses"].update({"key_sim": torchmetrics.MeanMetric()}) def training_step( self, batch: Tuple[NestedTensors, torch.Tensor], batch_idx: int diff --git a/test/renate/benchmark/models/test_l2p.py b/test/renate/benchmark/models/test_l2p.py index a52b0a3b..44e0be0b 100644 --- a/test/renate/benchmark/models/test_l2p.py +++ b/test/renate/benchmark/models/test_l2p.py @@ -3,7 +3,7 @@ import pytest import torch -from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool +from renate.benchmark.models.l2p import LearningToPromptTransformer, PromptPool, PromptedTransformer def test_prompt_pool(): @@ -62,3 +62,27 @@ def test_prompt_vision_transformer_trainable_parameters(backbone, num_trainable_ model = LearningToPromptTransformer(pretrained_model_name_or_path=backbone) n = sum(1 for x in model.parameters() if x.requires_grad) assert n == num_trainable_params + + +@pytest.mark.parametrize("backbone", ["google/vit-base-patch16-224", "bert-base-uncased"]) +@pytest.mark.parametrize("prompt", [None, torch.rand(3, 10, 768)]) +@pytest.mark.parametrize("cls_feat", [True, False]) +def test_prompted_transformer(backbone, prompt, cls_feat): + model = PromptedTransformer( + pretrained_model_name_or_path=backbone, + num_outputs=10, + prediction_strategy=None, + add_icarl_class_means=False, + ) + + B, P_len, _ = prompt.shape if prompt is not None else (5, 0, 0) + if "vit" in backbone: + # we are doing ViT. 
+ inputs = torch.rand(B, 3, 224, 224) + expected_output_size = (B, 197 + P_len, 768) if not cls_feat else (B, 768) + else: + inputs = {"input_ids": torch.randint(0, 10000, (B, 128))} + expected_output_size = (B, 768) + + out = model(inputs, prompt, cls_feat=cls_feat) + assert out.shape == expected_output_size From 073d6a2a3bf4e3e4dd1cd3d591ae6676890ff94b Mon Sep 17 00:00:00 2001 From: Prabhu Teja Date: Fri, 22 Sep 2023 10:33:32 +0200 Subject: [PATCH 87/89] CLEAR dataset download link update (#431) --- src/renate/benchmark/datasets/vision_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/renate/benchmark/datasets/vision_datasets.py b/src/renate/benchmark/datasets/vision_datasets.py index 9dc7045a..7368ae33 100644 --- a/src/renate/benchmark/datasets/vision_datasets.py +++ b/src/renate/benchmark/datasets/vision_datasets.py @@ -251,7 +251,7 @@ def prepare_data(self) -> None: self._data_path, self._src_bucket, self._src_object_name, - "https://clear-challenge.s3.us-east-2.amazonaws.com/", + "https://huggingface.co/datasets/elvishelvis6/CLEAR-Continual_Learning_Benchmark/resolve/main/", # noqa: E501 file_name, ) From f99cf1766990f49824191dc09f46484ecd389f6b Mon Sep 17 00:00:00 2001 From: wistuba Date: Fri, 22 Sep 2023 11:49:34 +0200 Subject: [PATCH 88/89] Refactor Offline-ER to work with `collate_fn` (#390) --- src/renate/memory/buffer.py | 4 +- .../updaters/experimental/offline_er.py | 57 ++++++----- src/renate/updaters/learner.py | 2 +- src/renate/utils/pytorch.py | 97 ++++++++++++++++++- .../configs/suites/quick/offline-er.json | 4 +- test/renate/utils/test_pytorch.py | 40 +++++++- 6 files changed, 168 insertions(+), 36 deletions(-) diff --git a/src/renate/memory/buffer.py b/src/renate/memory/buffer.py index 7b762a27..57d75d22 100644 --- a/src/renate/memory/buffer.py +++ b/src/renate/memory/buffer.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import copy from collections import defaultdict -from typing import Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch from torch.utils.data import Dataset @@ -67,7 +67,7 @@ def __len__(self) -> int: """Returns the current length of the buffer.""" return len(self._indices) - def __getitem__(self, idx: int) -> NestedTensors: + def __getitem__(self, idx: int) -> Tuple[NestedTensors, Dict[str, Any]]: """Reads the item at index `idx` of the buffer.""" i, j = self._indices[idx] data = self._datasets[i][j] diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index c541faea..84364634 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -6,19 +6,19 @@ import torch import torchmetrics from pytorch_lightning.loggers.logger import Logger -from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn import Parameter from torch.optim import Optimizer -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import ConcatDataset, DataLoader, Dataset from renate import defaults +from renate.memory import ReservoirBuffer from renate.models import RenateModule from renate.types import NestedTensors from renate.updaters.learner import ReplayLearner from renate.updaters.model_updater import SingleTrainingLoopUpdater from renate.utils.misc import maybe_populate_mask_and_ignore_logits -from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors +from 
renate.utils.pytorch import ConcatRandomSampler class OfflineExperienceReplayLearner(ReplayLearner): @@ -71,24 +71,29 @@ def on_model_update_start( self._num_points_current_task = len(train_dataset) def train_dataloader(self) -> DataLoader: - loaders = {} if len(self._memory_buffer) > self._memory_batch_size: - loaders["current_task"] = super().train_dataloader() - loaders["memory"] = DataLoader( - dataset=self._memory_buffer, - batch_size=self._memory_batch_size, - drop_last=True, - shuffle=True, + train_buffer = ReservoirBuffer( + max_size=self._num_points_current_task, + seed=0, + transform=self._train_transform, + target_transform=self._train_target_transform, + ) + train_buffer.update(self._train_dataset) + return DataLoader( + dataset=ConcatDataset([train_buffer, self._memory_buffer]), generator=self._rng, pin_memory=True, collate_fn=self._train_collate_fn, + batch_sampler=ConcatRandomSampler( + [self._num_points_current_task, len(self._memory_buffer)], + [self._batch_size, self._memory_batch_size], + 0, + generator=self._rng, + ), ) - else: - batch_size = self._batch_size - self._batch_size += self._memory_batch_size - loaders["current_task"] = super().train_dataloader() - self._batch_size = batch_size - return CombinedLoader(loaders, mode="max_size_cycle") + self._batch_size += self._memory_batch_size + self._memory_batch_size = 0 + return super().train_dataloader() def on_model_update_end(self) -> None: """Called right before a model update terminates.""" @@ -96,7 +101,9 @@ def on_model_update_end(self) -> None: self._num_points_previous_tasks += self._num_points_current_task self._num_points_current_task = -1 - def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT: + def training_step( + self, batch: Tuple[NestedTensors, Dict[str, Any]], batch_idx: int + ) -> STEP_OUTPUT: """PyTorch Lightning function to return the training loss.""" if self._loss_weight_new_data is None: alpha = self._num_points_current_task / ( @@ -105,21 +112,19 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) else: alpha = self._loss_weight_new_data alpha = torch.tensor(alpha, device=next(self.parameters()).device) - inputs, targets = batch["current_task"] - batch_size_current = get_length_nested_tensors(inputs) - if "memory" in batch: - (inputs_mem, targets_mem), _ = batch["memory"] - inputs = cat_nested_tensors((inputs, inputs_mem), 0) - targets = torch.cat((targets, targets_mem), 0) + if self._memory_batch_size: + (inputs, targets), _ = batch + else: + inputs, targets = batch outputs = self(inputs) outputs, self._class_mask = maybe_populate_mask_and_ignore_logits( self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs ) loss = self._loss_fn(outputs, targets) - if "memory" in batch: - loss_current = loss[:batch_size_current].mean() - loss_memory = loss[batch_size_current:].mean() + if self._memory_batch_size: + loss_current = loss[: self._batch_size].mean() + loss_memory = loss[self._batch_size :].mean() self._loss_collections["train_losses"]["base_loss"](loss_current) self._loss_collections["train_losses"]["memory_loss"](loss_memory) loss = alpha * loss_current + (1.0 - alpha) * loss_memory diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index d7e8777f..3180d034 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -153,7 +153,7 @@ def on_model_update_start( ) -> None: self._train_dataset = train_dataset self._val_dataset = val_dataset - 
self.val_enabled = val_dataset is not None and len(val_dataset) + self.val_enabled = val_dataset is not None and len(val_dataset) > 0 self._train_collate_fn = train_dataset_collate_fn self._val_collate_fn = val_dataset_collate_fn self._task_id = task_id diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 30a5738d..c143765a 100644 --- a/src/renate/utils/pytorch.py +++ b/src/renate/utils/pytorch.py @@ -2,10 +2,10 @@ # SPDX-License-Identifier: Apache-2.0 import logging import math -from typing import List, Optional, Set, Tuple, Union +from typing import Any, Iterator, List, Optional, Set, Tuple, Union import torch -from torch.utils.data import Dataset, random_split +from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split from transformers import BatchEncoding from renate import defaults @@ -20,11 +20,11 @@ def reinitialize_model_parameters(model: torch.nn.Module) -> None: implementations of exotic layers. A warning is logged for modules that do not implement `reset_parameters()`. - The actual logic of renitializing parameters depends on the type of layer. It may affect the + The actual logic of reinitializing parameters depends on the type of layer. It may affect the module's buffers (non-trainable parameters, e.g., batch norm stats) as well. Args: - model: The model to be re-initialized. + model: The model to be reinitialized. """ for module in model.modules(): # Skip modules without any parameters of their own. @@ -156,3 +156,92 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int valid_classes: A set of integers of valid classes. """ return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes] + + +class ConcatRandomSampler(Sampler[List[int]]): + """Sampler for sampling batches from ConcatDatasets. + + Each sampled batch is composed of batches of different BatchSamplers with the specified + batch sizes and ranges. + + To clarify the behavior, we provide a little example. + ``dataset_lengths = [5, 2]`` + ``batch_sizes = [3, 1]`` + + With this setting, we have a set of indices A={0..4} and B={5,6} for the two datasets. + The total batch size will be exactly 4. The first three elements are in that batch are + elements of A, the last an element of B. + An example batch could be ``[3, 1, 0, 6]``. + + Since we always provide a batch size of exactly ` sum(batch_sizes)``, we drop the last + batch. + + + Args: + dataset_lengths: The length for the different datasets. + batch_sizes: Batch sizes used for specific datasets. + complete_dataset_iteration: Provide an index to indicate over which dataset to fully + iterate. By default, stops whenever iteration is complete for any dataset. + generator: Generator used in sampling. + sampler: Lightning automatically passes a DistributedSamplerWrapper. Only used as an + indicator that we are in the distributed case. 
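
A minimal sketch of the sampler's behavior, using only the interface shown in this diff and the [5, 2] / [3, 1] example from the docstring above:

    from renate.utils.pytorch import ConcatRandomSampler

    sampler = ConcatRandomSampler(dataset_lengths=[5, 2], batch_sizes=[3, 1])
    assert len(sampler) == 1  # min(5 // 3, 2 // 1); the last incomplete batch is dropped
    for batch in sampler:
        # The first three indices point into the first dataset, the last one into the
        # second dataset, e.g. [3, 1, 0, 6].
        assert all(i < 5 for i in batch[:3]) and all(5 <= i < 7 for i in batch[3:])

Offline-ER relies on exactly this mechanism in its new train_dataloader (above) to compose every batch from batch_size fresh samples and memory_batch_size buffered samples of the ConcatDataset.
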
+ """ + + def __init__( + self, + dataset_lengths: List[int], + batch_sizes: List[int], + complete_dataset_iteration: Optional[int] = None, + generator: Any = None, + sampler: Sampler = None, + ) -> None: + self.batch_sizes = batch_sizes + self.complete_dataset_iteration = complete_dataset_iteration + self.subset_samplers = [] + data_start_idx = 0 + num_batches = [] + rank = 0 if sampler is None else sampler.rank + num_replicas = 1 if sampler is None else sampler.num_replicas + for dataset_length, batch_size in zip(dataset_lengths, batch_sizes): + data_end_idx = data_start_idx + dataset_length + start_idx = data_start_idx + round(dataset_length / num_replicas * rank) + end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1)) + subset_sampler = BatchSampler( + SubsetRandomSampler(list(range(start_idx, end_idx)), generator), + batch_size, + True, + ) + self.subset_samplers.append(subset_sampler) + num_batches.append((end_idx - start_idx) // batch_size) + data_start_idx = data_end_idx + self.length = ( + min(num_batches) + if self.complete_dataset_iteration is None + else num_batches[self.complete_dataset_iteration] + ) + + def __iter__(self) -> Iterator[List[int]]: + """Creates a batch with groups of indices where each group corresponds to one dataset.""" + if self.complete_dataset_iteration is None: + # Default case is iterating once over the shortest iterator. Works nicely with zip. + for samples in zip(*self.subset_samplers): + yield [j for i in samples for j in i] + else: + # Iterating over a specific iterator requires dealing with the length of other + # iterators. + iterators = [iter(sampler) for sampler in self.subset_samplers] + for s in iterators[self.complete_dataset_iteration]: + samples = [] + for i, iterator in enumerate(iterators): + if i != self.complete_dataset_iteration: + try: + samples += next(iterator) + except StopIteration: + iterators[i] = iter(self.subset_samplers[i]) + samples += next(iterators[i]) + else: + samples += s + yield samples + + def __len__(self): + return self.length diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 3732b00e..3ab5cfc4 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]], - "expected_accuracy_darwin": [[0.7315000295639038, 0.49000000953674316]] + "expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6234999895095825, 0.3779999911785126]], + "expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]] } diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py index f96bcfcd..ab1cfa5b 100644 --- a/test/renate/utils/test_pytorch.py +++ b/test/renate/utils/test_pytorch.py @@ -4,13 +4,14 @@ import pytest import torch import torchvision -from torch.utils.data import TensorDataset +from torch.utils.data import Sampler, TensorDataset from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule from renate.benchmark.scenarios import ClassIncrementalScenario from renate.memory.buffer import ReservoirBuffer from renate.utils import pytorch from renate.utils.pytorch import ( + ConcatRandomSampler, cat_nested_tensors, complementary_indices, get_length_nested_tensors, 
@@ -150,3 +151,40 @@ def test_unique_classes(tmpdir, test_dataset): buffer.update(ds, metadata) predicted_unique = unique_classes(buffer) assert predicted_unique == set(list(range(10))) + + +@pytest.mark.parametrize( + "complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]] +) +def test_concat_random_sampler(complete_dataset_iteration, expected_batches): + sampler = ConcatRandomSampler( + dataset_lengths=[15, 5, 20], + batch_sizes=[2, 1, 8], + complete_dataset_iteration=complete_dataset_iteration, + ) + assert len(sampler) == expected_batches + num_batches = 0 + for sample in sampler: + assert all([s < 15 for s in sample[:2]]) + assert all([15 <= s < 20 for s in sample[2:3]]) + assert all([20 <= s < 40 for s in sample[3:]]) + num_batches += 1 + assert num_batches == expected_batches + + +def test_concat_random_sampler_distributed(): + """Tests behavior in case of distributed computing.""" + mock_sampler = Sampler(None) + mock_sampler.rank = 1 + mock_sampler.num_replicas = 2 + expected_batches = 2 + sampler = ConcatRandomSampler( + dataset_lengths=[16, 10], batch_sizes=[2, 2], sampler=mock_sampler + ) + assert len(sampler) == expected_batches + num_batches = 0 + for sample in sampler: + assert all([7 < s < 16 for s in sample[:2]]) + assert all([21 <= s < 26 for s in sample[2:]]) + num_batches += 1 + assert num_batches == expected_batches From 60ea7b6a1ecfb8e3a274d5ea5d75a92294db234e Mon Sep 17 00:00:00 2001 From: wistuba Date: Mon, 25 Sep 2023 13:53:42 +0200 Subject: [PATCH 89/89] Prepare Release 0.4.0 (#429) --- .github/workflows/run_renate.yml | 2 +- .github/workflows/sagemaker_tests.yml | 2 +- README.rst | 12 +- .../datasets/multitext.json | 4 + .../experiment_configs/datasets/yearbook.json | 7 + .../fine-tuning-clear10-vitb16.json | 6 + .../fine-tuning-clear100-vitb16.json | 6 + .../fine-tuning-domainnet-vitb16.json | 6 + .../fine-tuning-fmow-vitb16.json | 6 + .../fine-tuning-multitext.json | 6 + .../fine-tuning-yearbook-vitb16.json | 6 + .../fine-tuning-yearbook.json | 6 + .../joint-cifar100-vitb16.json | 6 + .../joint-clear10-vitb16.json | 6 + .../joint-clear100-vitb16.json | 6 + .../joint-domainnet-vitb16.json | 6 + .../experiment_configs/joint-fmow-vitb16.json | 6 + .../experiment_configs/joint-multitext.json | 6 + .../joint-yearbook-vitb16.json | 6 + .../experiment_configs/joint-yearbook.json | 6 + .../experiment_configs/models/vit-b16.json | 3 + .../experiment_configs/offline-er-arxiv.json | 6 + .../offline-er-clear10-vitb16.json | 6 + .../offline-er-clear100-vitb16.json | 6 + .../offline-er-domainnet-vitb16.json | 6 + .../offline-er-fmow-vitb16.json | 6 + .../experiment_configs/offline-er-fmow.json | 6 + .../offline-er-huffpost.json | 6 + .../offline-er-multitext.json | 6 + .../offline-er-yearbook-vitb16.json | 6 + .../offline-er-yearbook.json | 6 + .../scenarios/clear10-vit-b16-10updates.json | 14 + .../scenarios/clear100-vit-b16-11updates.json | 14 + .../scenarios/domainnet-vit-b16-6updates.json | 15 + .../scenarios/fmow-vit-b16-16updates.json | 14 + .../scenarios/multitext-4updates.json | 10 + .../scenarios/yearbook-17updates.json | 33 + .../scenarios/yearbook-vit-b16-17updates.json | 28 + .../updaters/fine-tuning-clear-vit-b16.json | 4 + .../updaters/fine-tuning-multitext.json | 4 + .../updaters/fine-tuning-yearbook.json | 4 + .../updaters/joint-clear-vit-b16.json | 4 + .../updaters/joint-multitext.json | 4 + .../updaters/joint-yearbook.json | 4 + .../updaters/l2p-clear-vit-b16.json | 5 + .../updaters/l2p-domainnet-vit-b16.json | 5 + 
.../updaters/l2p-fmow-vit-b16.json | 5 + .../updaters/l2p-yearbook.json | 5 + .../updaters/offline-er-clear10-vit-b16.json | 6 + .../updaters/offline-er-clear100-vit-b16.json | 6 + .../updaters/offline-er-fmow.json | 6 + .../updaters/offline-er-huffpost.json | 6 + .../updaters/offline-er-multitext.json | 6 + .../updaters/offline-er-yearbook.json | 6 + doc/_images/improvement_renate.svg | 648 +++++++++--------- doc/requirements.txt | 10 +- requirements.txt | 18 +- src/renate/__init__.py | 2 +- 58 files changed, 729 insertions(+), 337 deletions(-) create mode 100644 benchmarks/experiment_configs/datasets/multitext.json create mode 100644 benchmarks/experiment_configs/datasets/yearbook.json create mode 100644 benchmarks/experiment_configs/fine-tuning-clear10-vitb16.json create mode 100644 benchmarks/experiment_configs/fine-tuning-clear100-vitb16.json create mode 100644 benchmarks/experiment_configs/fine-tuning-domainnet-vitb16.json create mode 100644 benchmarks/experiment_configs/fine-tuning-fmow-vitb16.json create mode 100644 benchmarks/experiment_configs/fine-tuning-multitext.json create mode 100644 benchmarks/experiment_configs/fine-tuning-yearbook-vitb16.json create mode 100644 benchmarks/experiment_configs/fine-tuning-yearbook.json create mode 100644 benchmarks/experiment_configs/joint-cifar100-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-clear10-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-clear100-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-domainnet-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-fmow-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-multitext.json create mode 100644 benchmarks/experiment_configs/joint-yearbook-vitb16.json create mode 100644 benchmarks/experiment_configs/joint-yearbook.json create mode 100644 benchmarks/experiment_configs/models/vit-b16.json create mode 100644 benchmarks/experiment_configs/offline-er-arxiv.json create mode 100644 benchmarks/experiment_configs/offline-er-clear10-vitb16.json create mode 100644 benchmarks/experiment_configs/offline-er-clear100-vitb16.json create mode 100644 benchmarks/experiment_configs/offline-er-domainnet-vitb16.json create mode 100644 benchmarks/experiment_configs/offline-er-fmow-vitb16.json create mode 100644 benchmarks/experiment_configs/offline-er-fmow.json create mode 100644 benchmarks/experiment_configs/offline-er-huffpost.json create mode 100644 benchmarks/experiment_configs/offline-er-multitext.json create mode 100644 benchmarks/experiment_configs/offline-er-yearbook-vitb16.json create mode 100644 benchmarks/experiment_configs/offline-er-yearbook.json create mode 100644 benchmarks/experiment_configs/scenarios/clear10-vit-b16-10updates.json create mode 100644 benchmarks/experiment_configs/scenarios/clear100-vit-b16-11updates.json create mode 100644 benchmarks/experiment_configs/scenarios/domainnet-vit-b16-6updates.json create mode 100644 benchmarks/experiment_configs/scenarios/fmow-vit-b16-16updates.json create mode 100644 benchmarks/experiment_configs/scenarios/multitext-4updates.json create mode 100644 benchmarks/experiment_configs/scenarios/yearbook-17updates.json create mode 100644 benchmarks/experiment_configs/scenarios/yearbook-vit-b16-17updates.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-clear-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-multitext.json create mode 100644 benchmarks/experiment_configs/updaters/fine-tuning-yearbook.json 
create mode 100644 benchmarks/experiment_configs/updaters/joint-clear-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/joint-multitext.json create mode 100644 benchmarks/experiment_configs/updaters/joint-yearbook.json create mode 100644 benchmarks/experiment_configs/updaters/l2p-clear-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/l2p-domainnet-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/l2p-fmow-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/l2p-yearbook.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-clear10-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-clear100-vit-b16.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-fmow.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-huffpost.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-multitext.json create mode 100644 benchmarks/experiment_configs/updaters/offline-er-yearbook.json diff --git a/.github/workflows/run_renate.yml b/.github/workflows/run_renate.yml index 7083b261..37f24f14 100644 --- a/.github/workflows/run_renate.yml +++ b/.github/workflows/run_renate.yml @@ -35,7 +35,7 @@ jobs: steps: - name: Configure AWS Credentials (if required) if: ${{ inputs.requires-aws-credentials == true }} - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v4 with: role-to-assume: ${{ secrets.PROD_AWS_END_TO_END_TEST_ROLE_ARN }} role-session-name: integtestsession diff --git a/.github/workflows/sagemaker_tests.yml b/.github/workflows/sagemaker_tests.yml index 400ff9f0..06695734 100644 --- a/.github/workflows/sagemaker_tests.yml +++ b/.github/workflows/sagemaker_tests.yml @@ -38,7 +38,7 @@ jobs: run: | python test/integration_tests/generate_requirements.py - name: Get Credentials - uses: aws-actions/configure-aws-credentials@v2 + uses: aws-actions/configure-aws-credentials@v4 with: role-to-assume: ${{ secrets.PROD_AWS_END_TO_END_TEST_ROLE_ARN }} role-session-name: integtestsession diff --git a/README.rst b/README.rst index 30f6028e..19b83d44 100644 --- a/README.rst +++ b/README.rst @@ -2,7 +2,7 @@ :target: # :alt: PyPI - Status .. image:: https://img.shields.io/github/v/release/awslabs/Renate - :target: https://github.com/awslabs/Renate/releases/tag/v0.3.1 + :target: https://github.com/awslabs/Renate/releases/tag/v0.4.0 :alt: Latest Release .. image:: https://img.shields.io/pypi/dm/Renate :target: https://pypistats.org/packages/renate @@ -116,14 +116,14 @@ If you did not find what you were looking for, open an `issue `_. + The training data was divided by year, and we trained sequentially on them. Fine-tuning refers to the strategy to learn on the first partition from scratch, and train on each of the subsequent partitions for few epochs only. - We compare to Experience Replay with a memory size of 500. - For both methods we use the same number of epochs and choose the best checkpoint + We compare to Experience Replay with an infinite memory size. + For both methods we use the same amount of training time and choose the best checkpoint using a validation set. Results reported are on the test set. -.. [#] The setup is the same as in the last experiment. However, this time we compare +.. [#] In this experiment, we consider class-incremental learning on CIFAR-10. We compare Experience Replay against a version in which its hyperparameters were tuned. 
diff --git a/benchmarks/experiment_configs/datasets/multitext.json b/benchmarks/experiment_configs/datasets/multitext.json new file mode 100644 index 00000000..7bc94002 --- /dev/null +++ b/benchmarks/experiment_configs/datasets/multitext.json @@ -0,0 +1,4 @@ +{ + "dataset_name": "MultiText", + "num_outputs": 33 +} diff --git a/benchmarks/experiment_configs/datasets/yearbook.json b/benchmarks/experiment_configs/datasets/yearbook.json new file mode 100644 index 00000000..51f3fced --- /dev/null +++ b/benchmarks/experiment_configs/datasets/yearbook.json @@ -0,0 +1,7 @@ +{ + "dataset_name": "yearbook", + "src_bucket": "my_bucket", + "src_object_name": "dataset/wildtime/yearbook.hdf5", + "num_inputs": 1024, + "num_outputs": 2 +} diff --git a/benchmarks/experiment_configs/fine-tuning-clear10-vitb16.json b/benchmarks/experiment_configs/fine-tuning-clear10-vitb16.json new file mode 100644 index 00000000..2ca69685 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-clear10-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-vit-b16-10updates.json", + "model": "vit-b16.json", + "updater": "fine-tuning-clear-vit-b16.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-clear100-vitb16.json b/benchmarks/experiment_configs/fine-tuning-clear100-vitb16.json new file mode 100644 index 00000000..750df2cb --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-clear100-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-vit-b16-11updates.json", + "model": "vit-b16.json", + "updater": "fine-tuning-clear-vit-b16.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-domainnet-vitb16.json b/benchmarks/experiment_configs/fine-tuning-domainnet-vitb16.json new file mode 100644 index 00000000..0ba77405 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-domainnet-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "domainnet-vit-b16-6updates.json", + "model": "vit-b16.json", + "updater": "fine-tuning-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-fmow-vitb16.json b/benchmarks/experiment_configs/fine-tuning-fmow-vitb16.json new file mode 100644 index 00000000..4f34e9fa --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-fmow-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "fmow-vit-b16-16updates.json", + "model": "vit-b16.json", + "updater": "fine-tuning-fmow.json", + "dataset": "fmow.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-multitext.json b/benchmarks/experiment_configs/fine-tuning-multitext.json new file mode 100644 index 00000000..af8d521a --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-multitext.json @@ -0,0 +1,6 @@ +{ + "scenario": "multitext-4updates.json", + "model": "bert.json", + "updater": "fine-tuning-multitext.json", + "dataset": "multitext.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-yearbook-vitb16.json b/benchmarks/experiment_configs/fine-tuning-yearbook-vitb16.json new file mode 100644 index 00000000..05ed64c6 --- /dev/null +++ b/benchmarks/experiment_configs/fine-tuning-yearbook-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-vit-b16-17updates.json", + "model": "vit-b16.json", + "updater": "fine-tuning-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/fine-tuning-yearbook.json b/benchmarks/experiment_configs/fine-tuning-yearbook.json new file mode 100644 index 00000000..d4ffa0e9 --- /dev/null +++ 
b/benchmarks/experiment_configs/fine-tuning-yearbook.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-17updates.json", + "model": "resnet18-cifar.json", + "updater": "fine-tuning-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/joint-cifar100-vitb16.json b/benchmarks/experiment_configs/joint-cifar100-vitb16.json new file mode 100644 index 00000000..488b7469 --- /dev/null +++ b/benchmarks/experiment_configs/joint-cifar100-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "cifar100-ci-10updates.json", + "model": "vit-b16.json", + "updater": "joint-cifar100.json", + "dataset": "cifar100.json" +} diff --git a/benchmarks/experiment_configs/joint-clear10-vitb16.json b/benchmarks/experiment_configs/joint-clear10-vitb16.json new file mode 100644 index 00000000..ac85f5be --- /dev/null +++ b/benchmarks/experiment_configs/joint-clear10-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-vit-b16-10updates.json", + "model": "vit-b16.json", + "updater": "joint-clear-vit-b16.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/joint-clear100-vitb16.json b/benchmarks/experiment_configs/joint-clear100-vitb16.json new file mode 100644 index 00000000..0b293b6f --- /dev/null +++ b/benchmarks/experiment_configs/joint-clear100-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-vit-b16-11updates.json", + "model": "vit-b16.json", + "updater": "joint-clear-vit-b16.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/joint-domainnet-vitb16.json b/benchmarks/experiment_configs/joint-domainnet-vitb16.json new file mode 100644 index 00000000..ddfb8aeb --- /dev/null +++ b/benchmarks/experiment_configs/joint-domainnet-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "domainnet-vit-b16-6updates.json", + "model": "vit-b16.json", + "updater": "joint-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/joint-fmow-vitb16.json b/benchmarks/experiment_configs/joint-fmow-vitb16.json new file mode 100644 index 00000000..cd1d5c95 --- /dev/null +++ b/benchmarks/experiment_configs/joint-fmow-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "fmow-vit-b16-16updates.json", + "model": "vit-b16.json", + "updater": "joint-fmow.json", + "dataset": "fmow.json" +} diff --git a/benchmarks/experiment_configs/joint-multitext.json b/benchmarks/experiment_configs/joint-multitext.json new file mode 100644 index 00000000..2ceafaad --- /dev/null +++ b/benchmarks/experiment_configs/joint-multitext.json @@ -0,0 +1,6 @@ +{ + "scenario": "multitext-4updates.json", + "model": "bert.json", + "updater": "joint-multitext.json", + "dataset": "multitext.json" +} diff --git a/benchmarks/experiment_configs/joint-yearbook-vitb16.json b/benchmarks/experiment_configs/joint-yearbook-vitb16.json new file mode 100644 index 00000000..5e2146cc --- /dev/null +++ b/benchmarks/experiment_configs/joint-yearbook-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-vit-b16-17updates.json", + "model": "vit-b16.json", + "updater": "joint-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/joint-yearbook.json b/benchmarks/experiment_configs/joint-yearbook.json new file mode 100644 index 00000000..ec434eab --- /dev/null +++ b/benchmarks/experiment_configs/joint-yearbook.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-17updates.json", + "model": "resnet18-cifar.json", + "updater": "joint-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/models/vit-b16.json 
b/benchmarks/experiment_configs/models/vit-b16.json new file mode 100644 index 00000000..0f3c87bb --- /dev/null +++ b/benchmarks/experiment_configs/models/vit-b16.json @@ -0,0 +1,3 @@ +{ + "model_name": "VisionTransformerB16" +} diff --git a/benchmarks/experiment_configs/offline-er-arxiv.json b/benchmarks/experiment_configs/offline-er-arxiv.json new file mode 100644 index 00000000..72ad5db0 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-arxiv.json @@ -0,0 +1,6 @@ +{ + "scenario": "arxiv-16updates.json", + "model": "bert.json", + "updater": "offline-er-arxiv.json", + "dataset": "arxiv.json" +} diff --git a/benchmarks/experiment_configs/offline-er-clear10-vitb16.json b/benchmarks/experiment_configs/offline-er-clear10-vitb16.json new file mode 100644 index 00000000..0e83597e --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-clear10-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear10-vit-b16-10updates.json", + "model": "vit-b16.json", + "updater": "offline-er-clear10-vit-b16.json", + "dataset": "clear10.json" +} diff --git a/benchmarks/experiment_configs/offline-er-clear100-vitb16.json b/benchmarks/experiment_configs/offline-er-clear100-vitb16.json new file mode 100644 index 00000000..886b33d8 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-clear100-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "clear100-vit-b16-11updates.json", + "model": "vit-b16.json", + "updater": "offline-er-clear100-vit-b16.json", + "dataset": "clear100.json" +} diff --git a/benchmarks/experiment_configs/offline-er-domainnet-vitb16.json b/benchmarks/experiment_configs/offline-er-domainnet-vitb16.json new file mode 100644 index 00000000..81ffc47c --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-domainnet-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "domainnet-vit-b16-6updates.json", + "model": "vit-b16.json", + "updater": "offline-er-domainnet.json", + "dataset": "domainnet.json" +} diff --git a/benchmarks/experiment_configs/offline-er-fmow-vitb16.json b/benchmarks/experiment_configs/offline-er-fmow-vitb16.json new file mode 100644 index 00000000..ea754696 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-fmow-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "fmow-vit-b16-16updates.json", + "model": "vit-b16.json", + "updater": "offline-er-fmow.json", + "dataset": "fmow.json" +} diff --git a/benchmarks/experiment_configs/offline-er-fmow.json b/benchmarks/experiment_configs/offline-er-fmow.json new file mode 100644 index 00000000..7da36d10 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-fmow.json @@ -0,0 +1,6 @@ +{ + "scenario": "fmow-16updates.json", + "model": "resnet18.json", + "updater": "offline-er-fmow.json", + "dataset": "fmow.json" +} diff --git a/benchmarks/experiment_configs/offline-er-huffpost.json b/benchmarks/experiment_configs/offline-er-huffpost.json new file mode 100644 index 00000000..37fe87f6 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-huffpost.json @@ -0,0 +1,6 @@ +{ + "scenario": "huffpost-7updates.json", + "model": "bert.json", + "updater": "offline-er-huffpost.json", + "dataset": "huffpost.json" +} diff --git a/benchmarks/experiment_configs/offline-er-multitext.json b/benchmarks/experiment_configs/offline-er-multitext.json new file mode 100644 index 00000000..b52aedb2 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-multitext.json @@ -0,0 +1,6 @@ +{ + "scenario": "multitext-4updates.json", + "model": "bert.json", + "updater": "offline-er-multitext.json", + "dataset": "multitext.json" +} diff --git 
a/benchmarks/experiment_configs/offline-er-yearbook-vitb16.json b/benchmarks/experiment_configs/offline-er-yearbook-vitb16.json new file mode 100644 index 00000000..23fcfb37 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-yearbook-vitb16.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-vit-b16-17updates.json", + "model": "vit-b16.json", + "updater": "offline-er-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/offline-er-yearbook.json b/benchmarks/experiment_configs/offline-er-yearbook.json new file mode 100644 index 00000000..5b242b89 --- /dev/null +++ b/benchmarks/experiment_configs/offline-er-yearbook.json @@ -0,0 +1,6 @@ +{ + "scenario": "yearbook-17updates.json", + "model": "resnet18-cifar.json", + "updater": "offline-er-yearbook.json", + "dataset": "yearbook.json" +} diff --git a/benchmarks/experiment_configs/scenarios/clear10-vit-b16-10updates.json b/benchmarks/experiment_configs/scenarios/clear10-vit-b16-10updates.json new file mode 100644 index 00000000..7efc9114 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/clear10-vit-b16-10updates.json @@ -0,0 +1,14 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 10, + "max_epochs": 10, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 10, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/scenarios/clear100-vit-b16-11updates.json b/benchmarks/experiment_configs/scenarios/clear100-vit-b16-11updates.json new file mode 100644 index 00000000..813fa760 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/clear100-vit-b16-11updates.json @@ -0,0 +1,14 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 11, + "max_epochs": 10, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 10, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/scenarios/domainnet-vit-b16-6updates.json b/benchmarks/experiment_configs/scenarios/domainnet-vit-b16-6updates.json new file mode 100644 index 00000000..b7a596b4 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/domainnet-vit-b16-6updates.json @@ -0,0 +1,15 @@ +{ + "val_size": 0.3, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 6, + "max_epochs": 10, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 10, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0, + "data_ids": ["clipart", "infograph", "painting", "quickdraw", "real", "sketch"] +} diff --git a/benchmarks/experiment_configs/scenarios/fmow-vit-b16-16updates.json b/benchmarks/experiment_configs/scenarios/fmow-vit-b16-16updates.json new file mode 100644 index 00000000..8daf30b7 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/fmow-vit-b16-16updates.json @@ -0,0 +1,14 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 16, + "max_epochs": 10, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 10, + 
"learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/scenarios/multitext-4updates.json b/benchmarks/experiment_configs/scenarios/multitext-4updates.json new file mode 100644 index 00000000..1bc47560 --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/multitext-4updates.json @@ -0,0 +1,10 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "num_tasks": 4, + "max_epochs": 2, + "optimizer": "AdamW", + "learning_rate": 0.00002, + "weight_decay": 0.01, + "data_ids": ["ag_news", "yelp_review_full", "dbpedia_14", "yahoo_answers_topics"] +} diff --git a/benchmarks/experiment_configs/scenarios/yearbook-17updates.json b/benchmarks/experiment_configs/scenarios/yearbook-17updates.json new file mode 100644 index 00000000..719acb1e --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/yearbook-17updates.json @@ -0,0 +1,33 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "groupings": [ + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39], + [40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59], + [60, 61, 62, 63, 64], + [65, 66, 67, 68, 69], + [70, 71, 72, 73, 74], + [75, 76, 77, 78, 79], + [80, 81, 82, 83] + ], + "num_tasks": 17, + "max_epochs": 50, + "optimizer": "SGD", + "learning_rate": 0.1, + "learning_rate_scheduler": "CosineAnnealingLR", + "learning_rate_scheduler_t_max": 50, + "learning_rate_scheduler_eta_min": 0.0001, + "learning_rate_scheduler_interval": "step", + "momentum": 0.0, + "weight_decay": 0.0 +} diff --git a/benchmarks/experiment_configs/scenarios/yearbook-vit-b16-17updates.json b/benchmarks/experiment_configs/scenarios/yearbook-vit-b16-17updates.json new file mode 100644 index 00000000..570fd88c --- /dev/null +++ b/benchmarks/experiment_configs/scenarios/yearbook-vit-b16-17updates.json @@ -0,0 +1,28 @@ +{ + "val_size": 0.1, + "scenario_name": "DataIncrementalScenario", + "groupings": [ + [0, 1, 2, 3, 4], + [5, 6, 7, 8, 9], + [10, 11, 12, 13, 14], + [15, 16, 17, 18, 19], + [20, 21, 22, 23, 24], + [25, 26, 27, 28, 29], + [30, 31, 32, 33, 34], + [35, 36, 37, 38, 39], + [40, 41, 42, 43, 44], + [45, 46, 47, 48, 49], + [50, 51, 52, 53, 54], + [55, 56, 57, 58, 59], + [60, 61, 62, 63, 64], + [65, 66, 67, 68, 69], + [70, 71, 72, 73, 74], + [75, 76, 77, 78, 79], + [80, 81, 82, 83] + ], + "num_tasks": 17, + "max_epochs": 4, + "optimizer": "AdamW", + "learning_rate": 0.00002, + "weight_decay": 0.01 +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-clear-vit-b16.json b/benchmarks/experiment_configs/updaters/fine-tuning-clear-vit-b16.json new file mode 100644 index 00000000..580b3d20 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-clear-vit-b16.json @@ -0,0 +1,4 @@ +{ + "updater": "FineTuning", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-multitext.json b/benchmarks/experiment_configs/updaters/fine-tuning-multitext.json new file mode 100644 index 00000000..09e7a594 --- /dev/null +++ b/benchmarks/experiment_configs/updaters/fine-tuning-multitext.json @@ -0,0 +1,4 @@ +{ + "updater": "FineTuning", + "batch_size": 64 +} diff --git a/benchmarks/experiment_configs/updaters/fine-tuning-yearbook.json b/benchmarks/experiment_configs/updaters/fine-tuning-yearbook.json new file mode 
new file mode 100644
index 00000000..09e7a594
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/fine-tuning-yearbook.json
@@ -0,0 +1,4 @@
+{
+    "updater": "FineTuning",
+    "batch_size": 64
+}
diff --git a/benchmarks/experiment_configs/updaters/joint-clear-vit-b16.json b/benchmarks/experiment_configs/updaters/joint-clear-vit-b16.json
new file mode 100644
index 00000000..dbd3d97c
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/joint-clear-vit-b16.json
@@ -0,0 +1,4 @@
+{
+    "updater": "Joint",
+    "batch_size": 64
+}
diff --git a/benchmarks/experiment_configs/updaters/joint-multitext.json b/benchmarks/experiment_configs/updaters/joint-multitext.json
new file mode 100644
index 00000000..0cfc3ad2
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/joint-multitext.json
@@ -0,0 +1,4 @@
+{
+    "updater": "Joint",
+    "batch_size": 64
+}
diff --git a/benchmarks/experiment_configs/updaters/joint-yearbook.json b/benchmarks/experiment_configs/updaters/joint-yearbook.json
new file mode 100644
index 00000000..0cfc3ad2
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/joint-yearbook.json
@@ -0,0 +1,4 @@
+{
+    "updater": "Joint",
+    "batch_size": 64
+}
diff --git a/benchmarks/experiment_configs/updaters/l2p-clear-vit-b16.json b/benchmarks/experiment_configs/updaters/l2p-clear-vit-b16.json
new file mode 100644
index 00000000..7cb26aba
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/l2p-clear-vit-b16.json
@@ -0,0 +1,5 @@
+{
+    "updater": "LearningToPrompt",
+    "batch_size": 64,
+    "prompt_sim_loss_weight": 0.1
+}
diff --git a/benchmarks/experiment_configs/updaters/l2p-domainnet-vit-b16.json b/benchmarks/experiment_configs/updaters/l2p-domainnet-vit-b16.json
new file mode 100644
index 00000000..7cb26aba
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/l2p-domainnet-vit-b16.json
@@ -0,0 +1,5 @@
+{
+    "updater": "LearningToPrompt",
+    "batch_size": 64,
+    "prompt_sim_loss_weight": 0.1
+}
diff --git a/benchmarks/experiment_configs/updaters/l2p-fmow-vit-b16.json b/benchmarks/experiment_configs/updaters/l2p-fmow-vit-b16.json
new file mode 100644
index 00000000..7cb26aba
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/l2p-fmow-vit-b16.json
@@ -0,0 +1,5 @@
+{
+    "updater": "LearningToPrompt",
+    "batch_size": 64,
+    "prompt_sim_loss_weight": 0.1
+}
diff --git a/benchmarks/experiment_configs/updaters/l2p-yearbook.json b/benchmarks/experiment_configs/updaters/l2p-yearbook.json
new file mode 100644
index 00000000..48ae4bc6
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/l2p-yearbook.json
@@ -0,0 +1,5 @@
+{
+    "updater": "LearningToPrompt",
+    "batch_size": 64,
+    "prompt_sim_loss_weight": 0.1
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear10-vit-b16.json b/benchmarks/experiment_configs/updaters/offline-er-clear10-vit-b16.json
new file mode 100644
index 00000000..8140d002
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-clear10-vit-b16.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 64,
+    "batch_memory_frac": 0.5,
+    "memory_size": 10000
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear100-vit-b16.json b/benchmarks/experiment_configs/updaters/offline-er-clear100-vit-b16.json
new file mode 100644
index 00000000..957a0c57
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-clear100-vit-b16.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 64,
+    "batch_memory_frac": 0.5,
+    "memory_size": 99999999
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-fmow.json b/benchmarks/experiment_configs/updaters/offline-er-fmow.json
new file mode 100644
index 00000000..edf91f37
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-fmow.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 64,
+    "batch_memory_frac": 0.5,
+    "memory_size": 11468
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-huffpost.json b/benchmarks/experiment_configs/updaters/offline-er-huffpost.json
new file mode 100644
index 00000000..9f2e92ab
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-huffpost.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 32,
+    "batch_memory_frac": 0.5,
+    "memory_size": 5751
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-multitext.json b/benchmarks/experiment_configs/updaters/offline-er-multitext.json
new file mode 100644
index 00000000..e137594b
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-multitext.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 64,
+    "batch_memory_frac": 0.5,
+    "memory_size": 11500
+}
diff --git a/benchmarks/experiment_configs/updaters/offline-er-yearbook.json b/benchmarks/experiment_configs/updaters/offline-er-yearbook.json
new file mode 100644
index 00000000..695ece29
--- /dev/null
+++ b/benchmarks/experiment_configs/updaters/offline-er-yearbook.json
@@ -0,0 +1,6 @@
+{
+    "updater": "Offline-ER",
+    "batch_size": 64,
+    "batch_memory_frac": 0.5,
+    "memory_size": 3343
+}
diff --git a/doc/_images/improvement_renate.svg b/doc/_images/improvement_renate.svg
index 68557d26..3b4805a1 100644
--- a/doc/_images/improvement_renate.svg
+++ b/doc/_images/improvement_renate.svg
[figure regenerated: metadata date 2022-11-24 -> 2023-09-25, Matplotlib v3.6.0 -> v3.5.0; SVG path data omitted]
diff --git a/doc/requirements.txt b/doc/requirements.txt
index 77449f40..feb0cd8e 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -1,11 +1,11 @@
-docutils==0.19
-Sphinx==7.1.0
+docutils==0.20.1
+Sphinx==7.2.6
 sphinx-copybutton==0.5.2
 sphinx-hoverxref==1.3.0
 sphinxext-opengraph==0.8.2
-pydata-sphinx-theme==0.13.3
-sphinx-autodoc-typehints==1.23.0
-sphinx-paramlinks==0.5.4
+pydata-sphinx-theme==0.14.0
+sphinx-autodoc-typehints==1.24.0
+sphinx-paramlinks==0.6.0
 
 # Temporarily added
 avalanche_lib==0.3.1
diff --git a/requirements.txt b/requirements.txt
index e741917e..241adf93 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,17 @@
-numpy>=1.17.2, <1.25.3
+numpy>=1.17.2, <1.26.1
 torch>=1.10.0, <1.13.2
 pandas>=1.4.0, <2.1.1
-boto3>=1.26.0, <1.26.139
+boto3>=1.26.0, <1.28.51
 requests>=2.31.0, <2.31.1
-sagemaker>=2.112.0, <2.158.1
-syne-tune[aws,gpsearchers]==0.6.0
+sagemaker>=2.112.0, <2.186.1
+syne-tune[aws,gpsearchers]>=0.6.0, <0.9.2
 pytorch-lightning>=1.8.0, <1.9.5
-Pillow>=9.0, <9.5.1
+Pillow>=9.0, <10.0.2
 tabulate>=0.9.0, <0.9.1
 tensorboardX>=2.5.0, <2.6.3
 torchmetrics>=0.11.0, <0.11.5
-torchvision>=0.13.0, <0.15.2
-deepspeed==0.9.1
-datasets>=2.9.0, <2.14.1
-transformers>=4.31.0, <4.31.1
+torchvision>=0.13.0, <0.15.3
+deepspeed>=0.9.0, <0.10.4
+datasets>=2.9.0, <2.14.6
+transformers>=4.31.0, <4.33.3
 scipy>=1.9.0, <1.11.3
diff --git a/src/renate/__init__.py b/src/renate/__init__.py
index a6ccdc9d..0a7410a7 100644
--- a/src/renate/__init__.py
+++ b/src/renate/__init__.py
@@ -14,4 +14,4 @@
 _renate_logger.addHandler(_handler)
 _renate_logger.propagate = False
 
-__version__ = "0.3.1"
+__version__ = "0.4.0"