
Refactor Offline-ER to work with collate_fn #390

Merged: 14 commits, Sep 22, 2023
4 changes: 2 additions & 2 deletions src/renate/memory/buffer.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import copy
from collections import defaultdict
from typing import Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple

import torch
from torch.utils.data import Dataset
@@ -67,7 +67,7 @@ def __len__(self) -> int:
"""Returns the current length of the buffer."""
return len(self._indices)

def __getitem__(self, idx: int) -> NestedTensors:
def __getitem__(self, idx: int) -> Tuple[NestedTensors, Dict[str, Any]]:
"""Reads the item at index `idx` of the buffer."""
i, j = self._indices[idx]
data = self._datasets[i][j]
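Note on the change above: a buffer item is now a pair of the stored data point and its metadata dictionary, which is what the updated type hint ``Tuple[NestedTensors, Dict[str, Any]]`` expresses. A minimal sketch of how such an item could be unpacked; the toy dataset and sizes are illustrative assumptions, not part of the PR:

```python
import torch
from torch.utils.data import TensorDataset

from renate.memory import ReservoirBuffer

# Toy dataset; shapes and label values are illustrative only.
dataset = TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,)))

buffer = ReservoirBuffer(max_size=5, seed=0, transform=None, target_transform=None)
buffer.update(dataset)

# __getitem__ returns Tuple[NestedTensors, Dict[str, Any]]: the stored data point
# (here an (inputs, target) pair) together with its metadata dictionary.
(inputs, target), metadata = buffer[0]
```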
57 changes: 31 additions & 26 deletions src/renate/updaters/experimental/offline_er.py
@@ -6,19 +6,19 @@
import torch
import torchmetrics
from pytorch_lightning.loggers.logger import Logger
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.types import STEP_OUTPUT
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import ConcatDataset, DataLoader, Dataset

from renate import defaults
from renate.memory import ReservoirBuffer
from renate.models import RenateModule
from renate.types import NestedTensors
from renate.updaters.learner import ReplayLearner
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.misc import maybe_populate_mask_and_ignore_logits
from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors
from renate.utils.pytorch import ConcatRandomSampler


class OfflineExperienceReplayLearner(ReplayLearner):
@@ -71,32 +71,39 @@ def on_model_update_start(
self._num_points_current_task = len(train_dataset)

def train_dataloader(self) -> DataLoader:
loaders = {}
if len(self._memory_buffer) > self._memory_batch_size:
loaders["current_task"] = super().train_dataloader()
loaders["memory"] = DataLoader(
dataset=self._memory_buffer,
batch_size=self._memory_batch_size,
drop_last=True,
shuffle=True,
train_buffer = ReservoirBuffer(
max_size=self._num_points_current_task,
seed=0,
transform=self._train_transform,
target_transform=self._train_target_transform,
)
train_buffer.update(self._train_dataset)
return DataLoader(
dataset=ConcatDataset([train_buffer, self._memory_buffer]),
generator=self._rng,
pin_memory=True,
collate_fn=self._train_collate_fn,
batch_sampler=ConcatRandomSampler(
[self._num_points_current_task, len(self._memory_buffer)],
[self._batch_size, self._memory_batch_size],
0,
generator=self._rng,
),
)
else:
batch_size = self._batch_size
self._batch_size += self._memory_batch_size
loaders["current_task"] = super().train_dataloader()
self._batch_size = batch_size
return CombinedLoader(loaders, mode="max_size_cycle")
self._batch_size += self._memory_batch_size
self._memory_batch_size = 0
return super().train_dataloader()

def on_model_update_end(self) -> None:
"""Called right before a model update terminates."""
self._memory_buffer.update(self._train_dataset)
self._num_points_previous_tasks += self._num_points_current_task
self._num_points_current_task = -1

def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT:
def training_step(
self, batch: Tuple[NestedTensors, Dict[str, Any]], batch_idx: int
) -> STEP_OUTPUT:
"""PyTorch Lightning function to return the training loss."""
if self._loss_weight_new_data is None:
alpha = self._num_points_current_task / (
@@ -105,21 +112,19 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
else:
alpha = self._loss_weight_new_data
alpha = torch.tensor(alpha, device=next(self.parameters()).device)
inputs, targets = batch["current_task"]
batch_size_current = get_length_nested_tensors(inputs)
if "memory" in batch:
(inputs_mem, targets_mem), _ = batch["memory"]
inputs = cat_nested_tensors((inputs, inputs_mem), 0)
targets = torch.cat((targets, targets_mem), 0)
if self._memory_batch_size:
(inputs, targets), _ = batch
else:
inputs, targets = batch
outputs = self(inputs)

outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs
)
loss = self._loss_fn(outputs, targets)
if "memory" in batch:
loss_current = loss[:batch_size_current].mean()
loss_memory = loss[batch_size_current:].mean()
if self._memory_batch_size:
loss_current = loss[: self._batch_size].mean()
loss_memory = loss[self._batch_size :].mean()
self._loss_collections["train_losses"]["base_loss"](loss_current)
self._loss_collections["train_losses"]["memory_loss"](loss_memory)
loss = alpha * loss_current + (1.0 - alpha) * loss_memory
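As a reading aid for the refactor above: with ConcatDataset and ConcatRandomSampler, each training batch is a single collated batch whose first self._batch_size samples come from the current task and whose remaining self._memory_batch_size samples come from the memory buffer, which is why training_step can split the per-sample loss by index. A self-contained sketch of that batch layout using plain PyTorch stand-ins (toy tensors and a toy batch sampler, not the Renate classes):

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

batch_size, memory_batch_size = 4, 2  # current-task and memory portions of each batch
current = TensorDataset(torch.randn(20, 3), torch.zeros(20, dtype=torch.long))
memory = TensorDataset(torch.randn(10, 3), torch.ones(10, dtype=torch.long))

def toy_batch_sampler():
    # Stand-in for ConcatRandomSampler: each batch takes `batch_size` indices from the
    # current-task range [0, 20) followed by `memory_batch_size` from the memory range [20, 30).
    current_perm = torch.randperm(20).tolist()
    memory_perm = (torch.randperm(10) + 20).tolist()
    for i in range(min(20 // batch_size, 10 // memory_batch_size)):  # drop_last is implicit
        yield (current_perm[i * batch_size:(i + 1) * batch_size]
               + memory_perm[i * memory_batch_size:(i + 1) * memory_batch_size])

loader = DataLoader(ConcatDataset([current, memory]), batch_sampler=list(toy_batch_sampler()))
inputs, targets = next(iter(loader))  # 6 samples: 4 current-task, then 2 memory

logits = torch.randn(batch_size + memory_batch_size, 2, requires_grad=True)  # toy model output
loss = torch.nn.functional.cross_entropy(logits, targets, reduction="none")
loss_current = loss[:batch_size].mean()   # first `batch_size` samples: current task
loss_memory = loss[batch_size:].mean()    # remaining samples: memory buffer
```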
2 changes: 1 addition & 1 deletion src/renate/updaters/learner.py
@@ -153,7 +153,7 @@ def on_model_update_start(
) -> None:
self._train_dataset = train_dataset
self._val_dataset = val_dataset
self.val_enabled = val_dataset is not None and len(val_dataset)
self.val_enabled = val_dataset is not None and len(val_dataset) > 0
self._train_collate_fn = train_dataset_collate_fn
self._val_collate_fn = val_dataset_collate_fn
self._task_id = task_id
97 changes: 93 additions & 4 deletions src/renate/utils/pytorch.py
@@ -2,10 +2,10 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import List, Optional, Set, Tuple, Union
from typing import Any, Iterator, List, Optional, Set, Tuple, Union

import torch
from torch.utils.data import Dataset, random_split
from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split
from transformers import BatchEncoding

from renate import defaults
@@ -20,11 +20,11 @@ def reinitialize_model_parameters(model: torch.nn.Module) -> None:
implementations of exotic layers. A warning is logged for modules that do not implement
`reset_parameters()`.

The actual logic of renitializing parameters depends on the type of layer. It may affect the
The actual logic of reinitializing parameters depends on the type of layer. It may affect the
module's buffers (non-trainable parameters, e.g., batch norm stats) as well.

Args:
model: The model to be re-initialized.
model: The model to be reinitialized.
"""
for module in model.modules():
# Skip modules without any parameters of their own.
@@ -156,3 +156,92 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int
valid_classes: A set of integers of valid classes.
"""
return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]


class ConcatRandomSampler(Sampler[List[int]]):
"""Sampler for sampling batches from ConcatDatasets.

Each sampled batch is composed of batches from different BatchSamplers with the specified
batch sizes and ranges.

To clarify the behavior, we provide a little example.
``dataset_lengths = [5, 2]``
``batch_sizes = [3, 1]``

With this setting, we have index sets A={0..4} and B={5,6} for the two datasets.
The total batch size will be exactly 4. The first three elements in that batch are
elements of A, the last an element of B.
An example batch could be ``[3, 1, 0, 6]``.

Since we always provide batches of exactly ``sum(batch_sizes)`` elements, incomplete
batches are dropped (drop_last is implicit).


Args:
dataset_lengths: The lengths of the different datasets.
batch_sizes: Batch sizes used for specific datasets.
complete_dataset_iteration: Provide an index to indicate over which dataset to fully
iterate. By default, stops whenever iteration is complete for any dataset.
generator: Generator used in sampling.
sampler: Lightning automatically passes a DistributedSamplerWrapper. Only used as an
indicator that we are in the distributed case.
"""

Review comment (Contributor), on ``complete_dataset_iteration``: Possibly rename?
Reply (Author): Suggestions?

def __init__(
self,
dataset_lengths: List[int],
batch_sizes: List[int],
complete_dataset_iteration: Optional[int] = None,
generator: Any = None,
sampler: Sampler = None,
) -> None:
self.batch_sizes = batch_sizes
self.complete_dataset_iteration = complete_dataset_iteration
self.subset_samplers = []
data_start_idx = 0
num_batches = []
rank = 0 if sampler is None else sampler.rank
num_replicas = 1 if sampler is None else sampler.num_replicas
for dataset_length, batch_size in zip(dataset_lengths, batch_sizes):
data_end_idx = data_start_idx + dataset_length
start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
subset_sampler = BatchSampler(
SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
batch_size,
True,
)

Review comment (Contributor): Why BatchSampler of SubsetRandomSampler?
Reply (Author): BatchSampler creates batches, SubsetRandomSampler creates random ints from the provided list (List[int] vs. int).
self.subset_samplers.append(subset_sampler)
num_batches.append((end_idx - start_idx) // batch_size)
data_start_idx = data_end_idx
self.length = (
min(num_batches)
if self.complete_dataset_iteration is None
else num_batches[self.complete_dataset_iteration]
)

def __iter__(self) -> Iterator[List[int]]:
"""Creates a batch with groups of indices where each group corresponds to one dataset."""

Review comment (Contributor): Can you add comments about the exact logic?
if self.complete_dataset_iteration is None:
# Default case is iterating once over the shortest iterator. Works nicely with zip.
for samples in zip(*self.subset_samplers):
yield [j for i in samples for j in i]
else:
# Iterating over a specific iterator requires dealing with the length of other
# iterators.
iterators = [iter(sampler) for sampler in self.subset_samplers]
for s in iterators[self.complete_dataset_iteration]:
samples = []
for i, iterator in enumerate(iterators):
if i != self.complete_dataset_iteration:
try:
samples += next(iterator)
except StopIteration:
iterators[i] = iter(self.subset_samplers[i])
samples += next(iterators[i])
else:
samples += s
yield samples

Review comment (Contributor): Is this optimized? Nested for-loops for each batch seems like a lot.
Reply (Author): There is no nested loop for each batch; it is a single loop over each iterator. In the first case this is hidden within zip, which also loops over each iterator and calls next.

def __len__(self):
return self.length
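To make the docstring example above concrete, here is a small usage sketch of ConcatRandomSampler mirroring ``dataset_lengths=[5, 2]`` and ``batch_sizes=[3, 1]``; the toy datasets are illustrative assumptions. Internally, each dataset gets its own BatchSampler over a SubsetRandomSampler, so every yielded list is a concatenation of per-dataset index batches (which is what the inline review question about BatchSampler and SubsetRandomSampler refers to):

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from renate.utils.pytorch import ConcatRandomSampler

dataset_a = TensorDataset(torch.arange(5).float().unsqueeze(1))  # occupies indices 0..4
dataset_b = TensorDataset(torch.arange(2).float().unsqueeze(1))  # occupies indices 5..6

sampler = ConcatRandomSampler(dataset_lengths=[5, 2], batch_sizes=[3, 1])
assert len(sampler) == 1  # min(5 // 3, 2 // 1) == 1; incomplete batches are dropped

for batch_indices in sampler:
    # e.g. [3, 1, 0, 6]: three random indices from dataset_a followed by one from dataset_b.
    print(batch_indices)

# Used as a batch_sampler, each DataLoader batch then mixes both datasets in fixed proportions.
loader = DataLoader(ConcatDataset([dataset_a, dataset_b]), batch_sampler=sampler)
```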
4 changes: 2 additions & 2 deletions test/integration_tests/configs/suites/quick/offline-er.json
@@ -5,6 +5,6 @@
"dataset": "cifar10.json",
"backend": "local",
"job_name": "class-incremental-mlp-offline-er",
"expected_accuracy_linux": [[0.6980000138282776, 0.546999990940094], [0.6514999866485596, 0.3725000023841858]],
"expected_accuracy_darwin": [[0.7315000295639038, 0.49000000953674316]]
"expected_accuracy_linux": [[0.7634999752044678, 0.40299999713897705], [0.6234999895095825, 0.3779999911785126]],
"expected_accuracy_darwin": [[0.7279999852180481, 0.4650000035762787]]
}
40 changes: 39 additions & 1 deletion test/renate/utils/test_pytorch.py
@@ -4,13 +4,14 @@
import pytest
import torch
import torchvision
from torch.utils.data import TensorDataset
from torch.utils.data import Sampler, TensorDataset

from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
from renate.benchmark.scenarios import ClassIncrementalScenario
from renate.memory.buffer import ReservoirBuffer
from renate.utils import pytorch
from renate.utils.pytorch import (
ConcatRandomSampler,
cat_nested_tensors,
complementary_indices,
get_length_nested_tensors,

Review comment (Contributor): Is it possible to add a DistributedSampler to a test?
Reply (Author): I've added a unit test for the distributed case instead.
@@ -150,3 +151,40 @@ def test_unique_classes(tmpdir, test_dataset):
buffer.update(ds, metadata)
predicted_unique = unique_classes(buffer)
assert predicted_unique == set(list(range(10)))


Review comment (Contributor): For None, the number of batches is 2 because 20 // 8 = 2?
Reply (Author): Yes, it is identical to the [2, 2] case.
Review comment (Contributor): So a drop_last is implicit?
Reply (Author): Yes; improved the doc.

@pytest.mark.parametrize(
"complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]]
)
def test_concat_random_sampler(complete_dataset_iteration, expected_batches):
sampler = ConcatRandomSampler(
dataset_lengths=[15, 5, 20],
batch_sizes=[2, 1, 8],
complete_dataset_iteration=complete_dataset_iteration,
)
assert len(sampler) == expected_batches
num_batches = 0
for sample in sampler:
assert all([s < 15 for s in sample[:2]])
assert all([15 <= s < 20 for s in sample[2:3]])
assert all([20 <= s < 40 for s in sample[3:]])
num_batches += 1
assert num_batches == expected_batches
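The expected batch counts in the parametrization above follow from integer division of each dataset length by its batch size (the last, incomplete batch is dropped, as discussed in the review thread); a quick check using nothing beyond the numbers in the test:

```python
dataset_lengths = [15, 5, 20]
batch_sizes = [2, 1, 8]

per_dataset_batches = [n // b for n, b in zip(dataset_lengths, batch_sizes)]  # [7, 5, 2]

# complete_dataset_iteration=None -> the shortest dataset determines the length: min == 2.
# complete_dataset_iteration=i    -> per_dataset_batches[i], i.e. 7, 5, or 2.
assert per_dataset_batches == [7, 5, 2]
assert min(per_dataset_batches) == 2
```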


def test_concat_random_sampler_distributed():
"""Tests behavior in case of distributed computing."""
mock_sampler = Sampler(None)
mock_sampler.rank = 1
mock_sampler.num_replicas = 2
expected_batches = 2
sampler = ConcatRandomSampler(
dataset_lengths=[16, 10], batch_sizes=[2, 2], sampler=mock_sampler
)
assert len(sampler) == expected_batches
num_batches = 0
for sample in sampler:
assert all([7 < s < 16 for s in sample[:2]])
assert all([21 <= s < 26 for s in sample[2:]])
num_batches += 1
assert num_batches == expected_batches
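The asserted index ranges in the distributed test follow from the rank-based slicing in ConcatRandomSampler.__init__, where each replica gets a contiguous slice of every dataset's index range; a small sketch of the arithmetic for rank 1 of 2:

```python
rank, num_replicas = 1, 2
dataset_lengths, batch_sizes = [16, 10], [2, 2]

data_start, ranges, num_batches = 0, [], []
for length, batch_size in zip(dataset_lengths, batch_sizes):
    start = data_start + round(length / num_replicas * rank)
    end = data_start + round(length / num_replicas * (rank + 1))
    ranges.append((start, end))
    num_batches.append((end - start) // batch_size)
    data_start += length

assert ranges == [(8, 16), (21, 26)]  # matches the `7 < s < 16` and `21 <= s < 26` checks
assert min(num_batches) == 2          # matches expected_batches
```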