From e5f0d3ebd2ab1cbd5774c2725d40757256c65944 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Wed, 23 Aug 2023 17:08:09 +0200
Subject: [PATCH 01/13] changes to Offline-ER without drop_last=True in learner

---
 src/renate/memory/buffer.py             |  4 +-
 .../updaters/experimental/offline_er.py | 56 ++++++++++---------
 src/renate/utils/pytorch.py             | 33 ++++++++++-
 3 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/src/renate/memory/buffer.py b/src/renate/memory/buffer.py
index 7b762a27..57d75d22 100644
--- a/src/renate/memory/buffer.py
+++ b/src/renate/memory/buffer.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
 from collections import defaultdict
-from typing import Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import torch
 from torch.utils.data import Dataset
@@ -67,7 +67,7 @@ def __len__(self) -> int:
         """Returns the current length of the buffer."""
         return len(self._indices)
 
-    def __getitem__(self, idx: int) -> NestedTensors:
+    def __getitem__(self, idx: int) -> Tuple[NestedTensors, Dict[str, Any]]:
         """Reads the item at index `idx` of the buffer."""
         i, j = self._indices[idx]
         data = self._datasets[i][j]
diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py
index c541faea..335ece94 100644
--- a/src/renate/updaters/experimental/offline_er.py
+++ b/src/renate/updaters/experimental/offline_er.py
@@ -6,19 +6,19 @@
 import torch
 import torchmetrics
 from pytorch_lightning.loggers.logger import Logger
-from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.types import STEP_OUTPUT
 from torch.nn import Parameter
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import ConcatDataset, DataLoader, Dataset
 
 from renate import defaults
+from renate.memory import ReservoirBuffer
 from renate.models import RenateModule
 from renate.types import NestedTensors
 from renate.updaters.learner import ReplayLearner
 from renate.updaters.model_updater import SingleTrainingLoopUpdater
 from renate.utils.misc import maybe_populate_mask_and_ignore_logits
-from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors
+from renate.utils.pytorch import ConcatRandomSampler
 
 
 class OfflineExperienceReplayLearner(ReplayLearner):
@@ -71,24 +71,28 @@ def on_model_update_start(
         self._num_points_current_task = len(train_dataset)
 
     def train_dataloader(self) -> DataLoader:
-        loaders = {}
         if len(self._memory_buffer) > self._memory_batch_size:
-            loaders["current_task"] = super().train_dataloader()
-            loaders["memory"] = DataLoader(
-                dataset=self._memory_buffer,
-                batch_size=self._memory_batch_size,
-                drop_last=True,
-                shuffle=True,
+            train_buffer = ReservoirBuffer(
+                max_size=self._num_points_current_task,
+                seed=0,
+                transform=self._train_transform,
+                target_transform=self._train_target_transform,
+            )
+            train_buffer.update(self._train_dataset)
+            return DataLoader(
+                dataset=ConcatDataset([train_buffer, self._memory_buffer]),
                 generator=self._rng,
                 pin_memory=True,
                 collate_fn=self._train_collate_fn,
+                batch_sampler=ConcatRandomSampler(
+                    [self._num_points_current_task, len(self._memory_buffer)],
+                    [self._batch_size, self._memory_batch_size],
+                    generator=self._rng,
+                ),
             )
-        else:
-            batch_size = self._batch_size
-            self._batch_size += self._memory_batch_size
-            loaders["current_task"] = super().train_dataloader()
-            self._batch_size = batch_size
-        return CombinedLoader(loaders, mode="max_size_cycle")
+        self._batch_size += self._memory_batch_size
+        self._memory_batch_size = 0
+        return super().train_dataloader()
 
     def on_model_update_end(self) -> None:
         """Called right before a model update terminates."""
@@ -96,7 +100,9 @@ def on_model_update_end(self) -> None:
         self._num_points_previous_tasks += self._num_points_current_task
         self._num_points_current_task = -1
 
-    def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT:
+    def training_step(
+        self, batch: Tuple[NestedTensors, Dict[str, Any]], batch_idx: int
+    ) -> STEP_OUTPUT:
         """PyTorch Lightning function to return the training loss."""
         if self._loss_weight_new_data is None:
             alpha = self._num_points_current_task / (
@@ -105,21 +111,19 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
         else:
             alpha = self._loss_weight_new_data
         alpha = torch.tensor(alpha, device=next(self.parameters()).device)
-        inputs, targets = batch["current_task"]
-        batch_size_current = get_length_nested_tensors(inputs)
-        if "memory" in batch:
-            (inputs_mem, targets_mem), _ = batch["memory"]
-            inputs = cat_nested_tensors((inputs, inputs_mem), 0)
-            targets = torch.cat((targets, targets_mem), 0)
+        if self._memory_batch_size:
+            (inputs, targets), _ = batch
+        else:
+            inputs, targets = batch
         outputs = self(inputs)
         outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
             self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs
         )
         loss = self._loss_fn(outputs, targets)
-        if "memory" in batch:
-            loss_current = loss[:batch_size_current].mean()
-            loss_memory = loss[batch_size_current:].mean()
+        if self._memory_batch_size:
+            loss_current = loss[: self._batch_size].mean()
+            loss_memory = loss[self._batch_size :].mean()
             self._loss_collections["train_losses"]["base_loss"](loss_current)
             self._loss_collections["train_losses"]["memory_loss"](loss_memory)
             loss = alpha * loss_current + (1.0 - alpha) * loss_memory
diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 30a5738d..10481b37 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import math
-from typing import List, Optional, Set, Tuple, Union
+from typing import Iterator, List, Optional, Set, Tuple, Union
 
 import torch
-from torch.utils.data import Dataset, random_split
+from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split
 from transformers import BatchEncoding
 
 from renate import defaults
@@ -156,3 +156,32 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int
         valid_classes: A set of integers of valid classes.
     """
     return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]
+
+
+class ConcatRandomSampler(Sampler[int]):
+    """Sampler for sampling batches from ConcatDatasets.
+
+    Args:
+        dataset_lengths: The length for the different datasets.
+        batch_sizes: Batch sizes used for specific datasets.
+        generator (Generator): Generator used in sampling.
+    """
+
+    def __init__(self, dataset_lengths, batch_sizes, generator=None) -> None:
+        self.batch_sizes = batch_sizes
+        self.subset_samplers = []
+        start_idx = 0
+        for dataset_length, batch_size in zip(dataset_lengths, batch_sizes):
+            end_idx = start_idx + dataset_length
+            self.subset_samplers.append(
+                BatchSampler(
+                    SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
+                    batch_size,
+                    True,
+                )
+            )
+            start_idx = end_idx
+
+    def __iter__(self) -> Iterator[List[int]]:
+        for samples in zip(*self.subset_samplers):
+            yield [j for i in samples for j in i]
""" - def __init__(self, dataset_lengths, batch_sizes, generator=None) -> None: + def __init__( + self, dataset_lengths, batch_sizes, complete_dataset_iteration=None, generator=None + ) -> None: self.batch_sizes = batch_sizes + self.complete_dataset_iteration = complete_dataset_iteration self.subset_samplers = [] start_idx = 0 + num_batches = [] for dataset_length, batch_size in zip(dataset_lengths, batch_sizes): end_idx = start_idx + dataset_length self.subset_samplers.append( @@ -180,8 +185,32 @@ def __init__(self, dataset_lengths, batch_sizes, generator=None) -> None: True, ) ) + num_batches.append(dataset_length // batch_size) start_idx = end_idx + self.length = ( + min(num_batches) + if complete_dataset_iteration is None + else num_batches[self.complete_dataset_iteration] + ) def __iter__(self) -> Iterator[List[int]]: - for samples in zip(*self.subset_samplers): - yield [j for i in samples for j in i] + if self.complete_dataset_iteration is None: + for samples in zip(*self.subset_samplers): + yield [j for i in samples for j in i] + else: + iterators = [iter(sampler) for sampler in self.subset_samplers] + for s in iterators[self.complete_dataset_iteration]: + samples = [] + for i, iterator in enumerate(iterators): + if i != self.complete_dataset_iteration: + try: + samples.append(next(iterator)) + except StopIteration: + iterators[i] = iter(self.subset_samplers[i]) + samples.append(next(iterators[i])) + else: + samples.append(s) + yield [j for i in samples for j in i] + + def __len__(self): + return self.length diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 3732b00e..6ba48f8f 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -6,5 +6,5 @@ "backend": "local", "job_name": "class-incremental-mlp-offline-er", "expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]], - "expected_accuracy_darwin": [[0.7315000295639038, 0.49000000953674316]] + "expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]] } From c23014da14ddf5b4de3ce3dbf7c5b046a8f6b8ff Mon Sep 17 00:00:00 2001 From: Martin Wistuba Date: Wed, 23 Aug 2023 18:29:39 +0200 Subject: [PATCH 03/13] update expected values --- test/integration_tests/configs/suites/quick/offline-er.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 6ba48f8f..6682ed5e 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]], + "expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6514999866485596, 0.3725000023841858]], "expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]] } From 0f8803ad5244c0398f5e3dab73bb4da0d4f1eb9d Mon Sep 17 00:00:00 2001 From: Martin Wistuba Date: Wed, 23 Aug 2023 18:30:11 +0200 Subject: [PATCH 04/13] flake --- src/renate/utils/pytorch.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 54da9482..68adaf28 100644 --- 
From c23014da14ddf5b4de3ce3dbf7c5b046a8f6b8ff Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Wed, 23 Aug 2023 18:29:39 +0200
Subject: [PATCH 03/13] update expected values

---
 test/integration_tests/configs/suites/quick/offline-er.json | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json
index 6ba48f8f..6682ed5e 100644
--- a/test/integration_tests/configs/suites/quick/offline-er.json
+++ b/test/integration_tests/configs/suites/quick/offline-er.json
@@ -5,6 +5,6 @@
     "dataset": "cifar10.json",
     "backend": "local",
     "job_name": "class-incremental-mlp-offline-er",
-    "expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]],
+    "expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6514999866485596, 0.3725000023841858]],
     "expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]]
 }

From 0f8803ad5244c0398f5e3dab73bb4da0d4f1eb9d Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Wed, 23 Aug 2023 18:30:11 +0200
Subject: [PATCH 04/13] flake

---
 src/renate/utils/pytorch.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 54da9482..68adaf28 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -5,7 +5,7 @@
 from typing import Iterator, List, Optional, Set, Tuple, Union
 
 import torch
-from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split
+from torch.utils.data import BatchSampler, Dataset, SubsetRandomSampler, random_split
 from transformers import BatchEncoding
 
 from renate import defaults
@@ -164,7 +164,8 @@ class ConcatRandomSampler(BatchSampler):
     Args:
         dataset_lengths: The length for the different datasets.
         batch_sizes: Batch sizes used for specific datasets.
-        complete_dataset_iteration: Provide an index to indicate over which dataset to fully iterate. By default, stops whenever iteration is complete for any dataset.
+        complete_dataset_iteration: Provide an index to indicate over which dataset to fully iterate. By default, stops
+            whenever iteration is complete for any dataset.
         generator (Generator): Generator used in sampling.
     """
 

From 546b8eb34251f1f3637af6cd32a8f133ccb17aff Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Wed, 23 Aug 2023 18:37:12 +0200
Subject: [PATCH 05/13] flake

---
 src/renate/utils/pytorch.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 68adaf28..4dc45bae 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -20,7 +20,7 @@ def reinitialize_model_parameters(model: torch.nn.Module) -> None:
     implementations of exotic layers. A warning is logged for modules that do not implement
     `reset_parameters()`.
 
-    The actual logic of renitializing parameters depends on the type of layer. It may affect the
+    The actual logic of reintializing parameters depends on the type of layer. It may affect the
     module's buffers (non-trainable parameters, e.g., batch norm stats) as well.
 
     Args:
@@ -164,8 +164,8 @@ class ConcatRandomSampler(BatchSampler):
     Args:
         dataset_lengths: The length for the different datasets.
         batch_sizes: Batch sizes used for specific datasets.
-        complete_dataset_iteration: Provide an index to indicate over which dataset to fully iterate. By default, stops
-            whenever iteration is complete for any dataset.
+        complete_dataset_iteration: Provide an index to indicate over which dataset to fully
+            iterate. By default, stops whenever iteration is complete for any dataset.
         generator (Generator): Generator used in sampling.
     """
 
""" From 8cea500457f784877a0e936d4de4f94a3fbe63f8 Mon Sep 17 00:00:00 2001 From: Martin Wistuba Date: Wed, 23 Aug 2023 18:53:08 +0200 Subject: [PATCH 06/13] update expected numbers and add unit test --- .../configs/suites/quick/offline-er.json | 2 +- test/renate/utils/test_pytorch.py | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/test/integration_tests/configs/suites/quick/offline-er.json b/test/integration_tests/configs/suites/quick/offline-er.json index 6682ed5e..3ab5cfc4 100644 --- a/test/integration_tests/configs/suites/quick/offline-er.json +++ b/test/integration_tests/configs/suites/quick/offline-er.json @@ -5,6 +5,6 @@ "dataset": "cifar10.json", "backend": "local", "job_name": "class-incremental-mlp-offline-er", - "expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6514999866485596, 0.3725000023841858]], + "expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6234999895095825, 0.3779999911785126]], "expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]] } diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py index f96bcfcd..73f02ccd 100644 --- a/test/renate/utils/test_pytorch.py +++ b/test/renate/utils/test_pytorch.py @@ -11,6 +11,7 @@ from renate.memory.buffer import ReservoirBuffer from renate.utils import pytorch from renate.utils.pytorch import ( + ConcatRandomSampler, cat_nested_tensors, complementary_indices, get_length_nested_tensors, @@ -150,3 +151,22 @@ def test_unique_classes(tmpdir, test_dataset): buffer.update(ds, metadata) predicted_unique = unique_classes(buffer) assert predicted_unique == set(list(range(10))) + + +@pytest.mark.parametrize( + "complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]] +) +def test_concat_random_sampler(complete_dataset_iteration, expected_batches): + sampler = ConcatRandomSampler( + dataset_lengths=[15, 5, 20], + batch_sizes=[2, 1, 8], + complete_dataset_iteration=complete_dataset_iteration, + ) + assert len(sampler) == expected_batches + num_batches = 0 + for sample in sampler: + assert all([s < 15 for s in sample[:2]]) + assert all([15 <= s < 20 for s in sample[2:3]]) + assert all([20 <= s < 40 for s in sample[3:]]) + num_batches += 1 + assert num_batches == expected_batches From 0b07ac3ea921941b96e407181c33794a3502e33e Mon Sep 17 00:00:00 2001 From: Martin Wistuba Date: Thu, 24 Aug 2023 10:22:57 +0200 Subject: [PATCH 07/13] small code improvements --- src/renate/utils/pytorch.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py index 4dc45bae..229b9ee8 100644 --- a/src/renate/utils/pytorch.py +++ b/src/renate/utils/pytorch.py @@ -195,23 +195,26 @@ def __init__( ) def __iter__(self) -> Iterator[List[int]]: + """Creates a batch with groups of indices where each group corresponds to one dataset.""" if self.complete_dataset_iteration is None: + # Default case is iterating once over the shortest iterator. Works nicely with zip. for samples in zip(*self.subset_samplers): yield [j for i in samples for j in i] else: + # Iterating over a specific iterator requires dealing with the length of other iterators. 
From 8190e26ecfd0e5965ac6915fd7cbc7613fe97345 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Tue, 19 Sep 2023 18:31:47 +0200
Subject: [PATCH 08/13] ddp compatible sampler

---
 src/renate/updaters/learner.py |  2 +-
 src/renate/utils/pytorch.py    | 34 +++++++++++++++++++-----------
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py
index d7e8777f..3180d034 100644
--- a/src/renate/updaters/learner.py
+++ b/src/renate/updaters/learner.py
@@ -153,7 +153,7 @@ def on_model_update_start(
     ) -> None:
         self._train_dataset = train_dataset
         self._val_dataset = val_dataset
-        self.val_enabled = val_dataset is not None and len(val_dataset)
+        self.val_enabled = val_dataset is not None and len(val_dataset) > 0
         self._train_collate_fn = train_dataset_collate_fn
         self._val_collate_fn = val_dataset_collate_fn
         self._task_id = task_id
diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 229b9ee8..8b0eff05 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -20,11 +20,11 @@ def reinitialize_model_parameters(model: torch.nn.Module) -> None:
     implementations of exotic layers. A warning is logged for modules that do not implement
     `reset_parameters()`.
 
-    The actual logic of reintializing parameters depends on the type of layer. It may affect the
+    The actual logic of reinitializing parameters depends on the type of layer. It may affect the
     module's buffers (non-trainable parameters, e.g., batch norm stats) as well.
 
     Args:
-        model: The model to be re-initialized.
+        model: The model to be reinitialized.
     """
     for module in model.modules():
         # Skip modules without any parameters of their own.
@@ -170,24 +170,32 @@ class ConcatRandomSampler(BatchSampler):
     """
 
     def __init__(
-        self, dataset_lengths, batch_sizes, complete_dataset_iteration=None, generator=None
+        self,
+        dataset_lengths,
+        batch_sizes,
+        complete_dataset_iteration=None,
+        generator=None,
+        sampler=None,
     ) -> None:
         self.batch_sizes = batch_sizes
         self.complete_dataset_iteration = complete_dataset_iteration
         self.subset_samplers = []
-        start_idx = 0
+        data_start_idx = 0
         num_batches = []
+        rank = 0 if sampler is None else sampler.rank
+        num_replicas = 1 if sampler is None else sampler.num_replicas
         for dataset_length, batch_size in zip(dataset_lengths, batch_sizes):
-            end_idx = start_idx + dataset_length
-            self.subset_samplers.append(
-                BatchSampler(
-                    SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
-                    batch_size,
-                    True,
-                )
+            data_end_idx = data_start_idx + dataset_length
+            start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
+            end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
+            subset_sampler = BatchSampler(
+                SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
+                batch_size,
+                True,
             )
-            num_batches.append(dataset_length // batch_size)
-            start_idx = end_idx
+            self.subset_samplers.append(subset_sampler)
+            num_batches.append((end_idx - start_idx + 1) // batch_size)
+            data_start_idx = data_end_idx
         self.length = (
             min(num_batches)
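A quick illustration (not part of the patch) of the per-rank splitting introduced above: each worker receives a contiguous slice of every dataset's index range, computed from its rank and the number of replicas. The helper name and values below are made up; the arithmetic mirrors the rounding used in the diff.

def rank_slice(offset, dataset_length, rank, num_replicas):
    # Contiguous index range assigned to one worker within a dataset that starts at `offset`.
    start = offset + round(dataset_length / num_replicas * rank)
    end = offset + round(dataset_length / num_replicas * (rank + 1))
    return start, end

# Two concatenated datasets of lengths 16 and 10; the second one starts at offset 16.
for rank in range(2):
    print(rank, rank_slice(0, 16, rank, 2), rank_slice(16, 10, rank, 2))
# -> rank 0 gets indices [0, 8) and [16, 21); rank 1 gets [8, 16) and [21, 26)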
From 8c69db7e9cc159725a605d1cfa50fbd20a618e10 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Tue, 19 Sep 2023 18:45:50 +0200
Subject: [PATCH 09/13] lint

---
 src/renate/utils/pytorch.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 8b0eff05..dab95b3b 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -209,7 +209,8 @@ def __iter__(self) -> Iterator[List[int]]:
             for samples in zip(*self.subset_samplers):
                 yield [j for i in samples for j in i]
         else:
-            # Iterating over a specific iterator requires dealing with the length of other iterators.
+            # Iterating over a specific iterator requires dealing with the length of other
+            # iterators.
             iterators = [iter(sampler) for sampler in self.subset_samplers]
             for s in iterators[self.complete_dataset_iteration]:
                 samples = []

From 40ffc8813fa0c177fe7988bce0b8093b8b5293b3 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Tue, 19 Sep 2023 19:12:04 +0200
Subject: [PATCH 10/13] fix small bug

---
 src/renate/utils/pytorch.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index dab95b3b..3fe85e0a 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -194,7 +194,7 @@ def __init__(
                 True,
             )
             self.subset_samplers.append(subset_sampler)
-            num_batches.append((end_idx - start_idx + 1) // batch_size)
+            num_batches.append((end_idx - start_idx) // batch_size)
             data_start_idx = data_end_idx
         self.length = (
             min(num_batches)

From a3bf468a34abcc8f511ed5f88334ceba3666f5c4 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Tue, 19 Sep 2023 20:00:16 +0200
Subject: [PATCH 11/13] add typing

---
 src/renate/utils/pytorch.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index 3fe85e0a..fbf95228 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import math
-from typing import Iterator, List, Optional, Set, Tuple, Union
+from typing import Any, Iterator, List, Optional, Set, Tuple, Union
 
 import torch
-from torch.utils.data import BatchSampler, Dataset, SubsetRandomSampler, random_split
+from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split
 from transformers import BatchEncoding
 
 from renate import defaults
@@ -166,16 +166,18 @@ class ConcatRandomSampler(BatchSampler):
         batch_sizes: Batch sizes used for specific datasets.
         complete_dataset_iteration: Provide an index to indicate over which dataset to fully
             iterate. By default, stops whenever iteration is complete for any dataset.
-        generator (Generator): Generator used in sampling.
+        generator: Generator used in sampling.
+        sampler: Lightning automatically passes a DistributedSamplerWrapper. Only used as an
+            indicator that we are in the distributed case.
     """
 
     def __init__(
         self,
-        dataset_lengths,
-        batch_sizes,
-        complete_dataset_iteration=None,
-        generator=None,
-        sampler=None,
+        dataset_lengths: List[int],
+        batch_sizes: List[int],
+        complete_dataset_iteration: Optional[int] = None,
+        generator: Any = None,
+        sampler: Sampler = None,
     ) -> None:
         self.batch_sizes = batch_sizes
         self.complete_dataset_iteration = complete_dataset_iteration
From 2746391341d21fdec5673852a4ad2f3aaaae6938 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Thu, 21 Sep 2023 18:39:11 +0200
Subject: [PATCH 12/13] provide better docs for batch sampler

---
 src/renate/utils/pytorch.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/renate/utils/pytorch.py b/src/renate/utils/pytorch.py
index fbf95228..c143765a 100644
--- a/src/renate/utils/pytorch.py
+++ b/src/renate/utils/pytorch.py
@@ -158,9 +158,25 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int
     return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]
 
 
-class ConcatRandomSampler(BatchSampler):
+class ConcatRandomSampler(Sampler[List[int]]):
     """Sampler for sampling batches from ConcatDatasets.
 
+    Each sampled batch is composed of batches of different BatchSamplers with the specified
+    batch sizes and ranges.
+
+    To clarify the behavior, we provide a little example.
+
+    ``dataset_lengths = [5, 2]``
+    ``batch_sizes = [3, 1]``
+
+    With this setting, we have a set of indices A={0..4} and B={5,6} for the two datasets.
+    The total batch size will be exactly 4. The first three elements in that batch are
+    elements of A, the last an element of B.
+    An example batch could be ``[3, 1, 0, 6]``.
+
+    Since we always provide a batch size of exactly ``sum(batch_sizes)``, we drop the last
+    batch.
+
     Args:
         dataset_lengths: The length for the different datasets.
         batch_sizes: Batch sizes used for specific datasets.
@@ -200,7 +216,7 @@ def __init__(
             data_start_idx = data_end_idx
         self.length = (
             min(num_batches)
-            if complete_dataset_iteration is None
+            if self.complete_dataset_iteration is None
             else num_batches[self.complete_dataset_iteration]
         )

From 615dc1677fdc89addc5a96d15a250fca6b3d05d7 Mon Sep 17 00:00:00 2001
From: Martin Wistuba
Date: Fri, 22 Sep 2023 10:31:37 +0200
Subject: [PATCH 13/13] test distributed behavior

---
 test/renate/utils/test_pytorch.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/test/renate/utils/test_pytorch.py b/test/renate/utils/test_pytorch.py
index 73f02ccd..ab1cfa5b 100644
--- a/test/renate/utils/test_pytorch.py
+++ b/test/renate/utils/test_pytorch.py
@@ -4,7 +4,7 @@
 import pytest
 import torch
 import torchvision
-from torch.utils.data import TensorDataset
+from torch.utils.data import Sampler, TensorDataset
 
 from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
 from renate.benchmark.scenarios import ClassIncrementalScenario
@@ -170,3 +170,21 @@ def test_concat_random_sampler(complete_dataset_iteration, expected_batches):
         assert all([20 <= s < 40 for s in sample[3:]])
         num_batches += 1
     assert num_batches == expected_batches
+
+
+def test_concat_random_sampler_distributed():
+    """Tests behavior in case of distributed computing."""
+    mock_sampler = Sampler(None)
+    mock_sampler.rank = 1
+    mock_sampler.num_replicas = 2
+    expected_batches = 2
+    sampler = ConcatRandomSampler(
+        dataset_lengths=[16, 10], batch_sizes=[2, 2], sampler=mock_sampler
+    )
+    assert len(sampler) == expected_batches
+    num_batches = 0
+    for sample in sampler:
+        assert all([7 < s < 16 for s in sample[:2]])
+        assert all([21 <= s < 26 for s in sample[2:]])
+        num_batches += 1
+    assert num_batches == expected_batches
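A usage sketch for the sampler documented in the patch above (it assumes a Renate build that already contains this version of ConcatRandomSampler). With ``dataset_lengths=[5, 2]`` and ``batch_sizes=[3, 1]``, every batch holds three indices from 0..4 followed by one index from 5..6, and the last incomplete batch is dropped, so this example yields a single batch.

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from renate.utils.pytorch import ConcatRandomSampler

data_a = TensorDataset(torch.zeros(5, 1), torch.zeros(5, dtype=torch.long))
data_b = TensorDataset(torch.ones(2, 1), torch.ones(2, dtype=torch.long))
sampler = ConcatRandomSampler(dataset_lengths=[5, 2], batch_sizes=[3, 1])
loader = DataLoader(ConcatDataset([data_a, data_b]), batch_sampler=sampler)

for batch_indices in sampler:
    print(batch_indices)  # e.g. [3, 1, 0, 6]
for inputs, targets in loader:
    # The first three targets come from data_a (0), the last one from data_b (1).
    print(inputs.shape, targets.tolist())  # torch.Size([4, 1]) [0, 0, 0, 1]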