From 5d0962beadfff2284ec0f520ed302ebba2b8c61e Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Tue, 11 Jul 2023 14:17:59 +0200
Subject: [PATCH 1/2] fix offline-er and add tests

---
 .../updaters/experimental/offline_er.py | 18 +++++-----
 src/renate/utils/pytorch.py             | 36 ++++++++++++++++++-
 test/renate/utils/test_pytorch.py       | 33 ++++++++++++++++-
 3 files changed, 77 insertions(+), 10 deletions(-)

diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py
index 49b03bda..066b3a52 100644
--- a/src/renate/updaters/experimental/offline_er.py
+++ b/src/renate/updaters/experimental/offline_er.py
@@ -17,7 +17,11 @@
 from renate.types import NestedTensors
 from renate.updaters.learner import ReplayLearner
 from renate.updaters.model_updater import SingleTrainingLoopUpdater
-from renate.utils.pytorch import move_tensors_to_device
+from renate.utils.pytorch import (
+    cat_nested_tensors,
+    get_shape_nested_tensors,
+    move_tensors_to_device,
+)
 
 
 class OfflineExperienceReplayLearner(ReplayLearner):
@@ -96,9 +100,7 @@ def on_model_update_end(self) -> None:
         self._num_points_previous_tasks += self._num_points_current_task
         self._num_points_current_task = -1
 
-    def training_step(
-        self, batch: Dict[str, Tuple[NestedTensors, torch.Tensor]], batch_idx: int
-    ) -> STEP_OUTPUT:
+    def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT:
         """PyTorch Lightning function to return the training loss."""
         if self._loss_weight_new_data is None:
             alpha = self._num_points_current_task / (
@@ -107,13 +109,13 @@ def training_step(
         else:
             alpha = self._loss_weight_new_data
         inputs, targets = batch["current_task"]
-        device = inputs.device
-        batch_size_current = inputs.shape[0]
+        device = next(self.parameters()).device
+        batch_size_current = get_shape_nested_tensors(inputs)[0]
         batch_size_mem = 0
         if "memory" in batch:
             (inputs_mem, targets_mem), _ = batch["memory"]
-            batch_size_mem = inputs_mem.shape[0]
-            inputs = torch.cat((inputs, inputs_mem), 0)
+            batch_size_mem = get_shape_nested_tensors(inputs_mem)[0]
+            inputs = cat_nested_tensors((inputs, inputs_mem), 0)
             targets = torch.cat((targets, targets_mem), 0)
         outputs = self(inputs)
         loss = self._loss_fn(outputs, targets)
diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 6f130236..f1cdc48a 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import math
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
 
 import torch
 from torch.utils.data import Dataset, random_split
@@ -87,3 +87,37 @@ def move_tensors_to_device(tensors: NestedTensors, device: torch.device) -> Nest
         "Expected `tensors` to be a nested structure of tensors, tuples, list and dict; "
         f"discovered {type(tensors)}."
     )
+
+
+def get_shape_nested_tensors(batch: NestedTensors) -> torch.Size:
+    """Given a NestedTensor, return its batch size."""
+    if isinstance(batch, torch.Tensor):
+        return batch.shape
+    if isinstance(batch, tuple):
+        return batch[0].shape
+    if isinstance(batch, dict):
+        return batch[next(iter(batch.keys()))].shape
+
+
+def cat_nested_tensors(
+    nested_tensors: Union[Tuple[NestedTensors], List[NestedTensors]], axis: int = 0
+) -> NestedTensors:
+    """Concatenates the two NestedTensors.
+
+    Equivalent of PyTorch's ``cat`` function for ``NestedTensors``.
+
+    Args:
+        nested_tensors: Tensors to be concatenated.
+        axis: Concatenation axis.
+    """
+    if isinstance(nested_tensors[0], torch.Tensor):
+        return torch.cat(nested_tensors, axis)
+    if isinstance(nested_tensors[0], tuple):
+        return tuple(
+            cat_nested_tensors(nested_tensor, axis) for nested_tensor in zip(*nested_tensors)
+        )
+    if isinstance(nested_tensors[0], dict):
+        return {
+            key: cat_nested_tensors([nested_tensor[key] for nested_tensor in nested_tensors], axis)
+            for key in nested_tensors[0]
+        }
diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py
index 6ae2fdde..0fdc3090 100644
--- a/test/renate/utils/test_pytorch.py
+++ b/test/renate/utils/test_pytorch.py
@@ -6,7 +6,7 @@
 from torch.utils.data import TensorDataset
 
 from renate.utils import pytorch
-from renate.utils.pytorch import randomly_split_data
+from renate.utils.pytorch import cat_nested_tensors, get_shape_nested_tensors, randomly_split_data
 
 
 @pytest.mark.parametrize("model", [torchvision.models.resnet18(pretrained=True)])
@@ -61,3 +61,34 @@ def test_random_splitting_sample_split_with_same_random_seed():
     for i in range(5):
         assert torch.equal(d_1_split_1[i][0], d_2_split_1[i][0])
         assert torch.equal(d_1_split_2[i][0], d_2_split_2[i][0])
+
+
+def test_get_shape_nested_tensors():
+    expected_batch_size = 10
+    t = torch.zeros(expected_batch_size)
+    assert get_shape_nested_tensors(t)[0] == expected_batch_size
+    tuple_tensor = (t, t)
+    assert get_shape_nested_tensors(tuple_tensor)[0] == expected_batch_size
+    dict_tensor = {"k1": t, "k2": t}
+    assert get_shape_nested_tensors(dict_tensor)[0] == expected_batch_size
+
+
+def test_cat_nested_tensors():
+    tensor_dim = 2
+    zeros = torch.zeros(tensor_dim)
+    ones = torch.ones(tensor_dim)
+    expected_mean = 0.5
+    expected_batch_size = 2 * tensor_dim
+    result = cat_nested_tensors((zeros, ones))
+    assert get_shape_nested_tensors(result)[0] == expected_batch_size
+    assert result.mean() == expected_mean
+    tuple_tensor = (zeros, ones)
+    result = cat_nested_tensors((tuple_tensor, tuple_tensor))
+    assert get_shape_nested_tensors(result)[0] == expected_batch_size
+    assert result[0].sum() == 0
+    assert result[1].sum() == 2 * tensor_dim
+    dict_tensor = {"zeros": zeros, "ones": ones}
+    result = cat_nested_tensors((dict_tensor, dict_tensor))
+    assert get_shape_nested_tensors(result)[0] == expected_batch_size
+    assert result["zeros"].sum() == 0
+    assert result["ones"].sum() == 2 * tensor_dim

From 9de3484ecb459d04eff6da6c3dc2d612d25b571d Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Mon, 24 Jul 2023 16:33:26 +0200
Subject: [PATCH 2/2] improve tests and docstring

---
 .../updaters/experimental/offline_er.py |  6 +--
 src/renate/utils/pytorch.py             | 13 +++---
 test/renate/utils/test_pytorch.py       | 40 ++++++++++++-------
 3 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py
index 066b3a52..219a8c3c 100644
--- a/src/renate/updaters/experimental/offline_er.py
+++ b/src/renate/updaters/experimental/offline_er.py
@@ -19,7 +19,7 @@
 from renate.updaters.model_updater import SingleTrainingLoopUpdater
 from renate.utils.pytorch import (
     cat_nested_tensors,
-    get_shape_nested_tensors,
+    get_length_nested_tensors,
     move_tensors_to_device,
 )
 
@@ -110,11 +110,11 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
             alpha = self._loss_weight_new_data
         inputs, targets = batch["current_task"]
         device = next(self.parameters()).device
-        batch_size_current = get_shape_nested_tensors(inputs)[0]
+        batch_size_current = get_length_nested_tensors(inputs)
         batch_size_mem = 0
         if "memory" in batch:
             (inputs_mem, targets_mem), _ = batch["memory"]
-            batch_size_mem = get_shape_nested_tensors(inputs_mem)[0]
+            batch_size_mem = get_length_nested_tensors(inputs_mem)
             inputs = cat_nested_tensors((inputs, inputs_mem), 0)
             targets = torch.cat((targets, targets_mem), 0)
         outputs = self(inputs)
diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index f1cdc48a..0765e2cc 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -89,14 +89,17 @@ def move_tensors_to_device(tensors: NestedTensors, device: torch.device) -> Nest
     )
 
 
-def get_shape_nested_tensors(batch: NestedTensors) -> torch.Size:
-    """Given a NestedTensor, return its batch size."""
+def get_length_nested_tensors(batch: NestedTensors) -> torch.Size:
+    """Given a NestedTensor, return its length.
+
+    Assumes that the first axis in each element is the same.
+    """
     if isinstance(batch, torch.Tensor):
-        return batch.shape
+        return batch.shape[0]
     if isinstance(batch, tuple):
-        return batch[0].shape
+        return batch[0].shape[0]
     if isinstance(batch, dict):
-        return batch[next(iter(batch.keys()))].shape
+        return batch[next(iter(batch.keys()))].shape[0]
 
 
 def cat_nested_tensors(
diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py
index 0fdc3090..b247976a 100644
--- a/test/renate/utils/test_pytorch.py
+++ b/test/renate/utils/test_pytorch.py
@@ -6,7 +6,7 @@
 from torch.utils.data import TensorDataset
 
 from renate.utils import pytorch
-from renate.utils.pytorch import cat_nested_tensors, get_shape_nested_tensors, randomly_split_data
+from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors, randomly_split_data
 
 
 @pytest.mark.parametrize("model", [torchvision.models.resnet18(pretrained=True)])
@@ -63,32 +63,42 @@
         assert torch.equal(d_1_split_2[i][0], d_2_split_2[i][0])
 
 
-def test_get_shape_nested_tensors():
+def test_get_length_nested_tensors():
     expected_batch_size = 10
     t = torch.zeros(expected_batch_size)
-    assert get_shape_nested_tensors(t)[0] == expected_batch_size
+    assert get_length_nested_tensors(t) == expected_batch_size
     tuple_tensor = (t, t)
-    assert get_shape_nested_tensors(tuple_tensor)[0] == expected_batch_size
+    assert get_length_nested_tensors(tuple_tensor) == expected_batch_size
     dict_tensor = {"k1": t, "k2": t}
-    assert get_shape_nested_tensors(dict_tensor)[0] == expected_batch_size
+    assert get_length_nested_tensors(dict_tensor) == expected_batch_size
 
 
 def test_cat_nested_tensors():
     tensor_dim = 2
-    zeros = torch.zeros(tensor_dim)
-    ones = torch.ones(tensor_dim)
-    expected_mean = 0.5
-    expected_batch_size = 2 * tensor_dim
+    first_dim_ones = 8
+    zeros = torch.zeros((2, tensor_dim))
+    ones = torch.ones((first_dim_ones, tensor_dim))
     result = cat_nested_tensors((zeros, ones))
-    assert get_shape_nested_tensors(result)[0] == expected_batch_size
-    assert result.mean() == expected_mean
+    assert get_length_nested_tensors(result) == 10
+    assert result.mean() == 0.8
     tuple_tensor = (zeros, ones)
     result = cat_nested_tensors((tuple_tensor, tuple_tensor))
-    assert get_shape_nested_tensors(result)[0] == expected_batch_size
+    assert get_length_nested_tensors(result) == 4
     assert result[0].sum() == 0
-    assert result[1].sum() == 2 * tensor_dim
+    assert result[1].sum() == 2 * first_dim_ones * tensor_dim
     dict_tensor = {"zeros": zeros, "ones": ones}
     result = cat_nested_tensors((dict_tensor, dict_tensor))
-    assert get_shape_nested_tensors(result)[0] == expected_batch_size
+    assert get_length_nested_tensors(result) == 4
     assert result["zeros"].sum() == 0
-    assert result["ones"].sum() == 2 * tensor_dim
+    assert result["ones"].sum() == 2 * first_dim_ones * tensor_dim
+
+
+def test_cat_nested_tensors_wrong_shape():
+    tensor1 = torch.zeros((2, 2))
+    tensor2 = torch.zeros((2, 3))
+    with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"):
+        cat_nested_tensors((tensor1, tensor2))
+    with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"):
+        cat_nested_tensors(((tensor1, tensor1), (tensor1, tensor2)))
+    with pytest.raises(RuntimeError, match=r"Sizes of tensors must match except in dimension 0.*"):
+        cat_nested_tensors(({"k1": tensor1, "k2": tensor1}, {"k1": tensor1, "k2": tensor2}))
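
Reviewer note (not part of either patch above): a minimal usage sketch of the two helpers this series adds, mirroring what the new tests assert. The dict keys (input_ids, attention_mask) are made up for illustration; any NestedTensors structure whose elements share the same first axis works.

    import torch

    from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors

    # A dict-valued batch, e.g. tokenized text inputs (keys are illustrative only).
    current = {
        "input_ids": torch.zeros(4, 16, dtype=torch.long),
        "attention_mask": torch.ones(4, 16, dtype=torch.long),
    }
    memory = {
        "input_ids": torch.zeros(2, 16, dtype=torch.long),
        "attention_mask": torch.ones(2, 16, dtype=torch.long),
    }

    # get_length_nested_tensors returns the size of the first axis, not a full shape.
    assert get_length_nested_tensors(current) == 4
    # cat_nested_tensors concatenates element-wise, recursing into tuples and dicts.
    combined = cat_nested_tensors((current, memory), 0)
    assert get_length_nested_tensors(combined) == 6

This is the case the offline-ER fix targets: training_step can now concatenate the current-task batch with a replayed memory batch even when the inputs are tuples or dicts rather than a single tensor.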