Refactor Offline-ER to work with collate_fn (#390)
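For readers skimming the diff, here is a hedged sketch of why a batch sampler helps with the PR's goal (illustrative plumbing only, not Renate's actual Offline-ER code): a `DataLoader` driven by a `batch_sampler` fetches all indices of a joint batch and passes the resulting list of samples to a single `collate_fn`, so custom collation applies uniformly to memory and current-task samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))

def collate_fn(samples):  # hypothetical stand-in for a user-provided collate_fn
    # Receives one list containing *all* samples of the joint batch.
    return torch.stack([s[0] for s in samples])

# Any iterable of index batches works as a batch_sampler, e.g., ConcatRandomSampler.
loader = DataLoader(dataset, batch_sampler=[[0, 1, 5], [2, 3, 6]], collate_fn=collate_fn)
print([b.tolist() for b in loader])  # [[0, 1, 5], [2, 3, 6]]
```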
In `renate/utils/pytorch.py`:

```diff
@@ -2,10 +2,10 @@
 # SPDX-License-Identifier: Apache-2.0
 import logging
 import math
-from typing import List, Optional, Set, Tuple, Union
+from typing import Any, Iterator, List, Optional, Set, Tuple, Union

 import torch
-from torch.utils.data import Dataset, random_split
+from torch.utils.data import BatchSampler, Dataset, Sampler, SubsetRandomSampler, random_split
 from transformers import BatchEncoding

 from renate import defaults
```
```diff
@@ -20,11 +20,11 @@ def reinitialize_model_parameters(model: torch.nn.Module) -> None:
     implementations of exotic layers. A warning is logged for modules that do not implement
     `reset_parameters()`.

-    The actual logic of renitializing parameters depends on the type of layer. It may affect the
+    The actual logic of reinitializing parameters depends on the type of layer. It may affect the
     module's buffers (non-trainable parameters, e.g., batch norm stats) as well.

     Args:
-        model: The model to be re-initialized.
+        model: The model to be reinitialized.
     """
     for module in model.modules():
         # Skip modules without any parameters of their own.
```
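For context, a minimal sketch of how such a `reset_parameters()`-based loop typically looks; the diff only shows the first lines of the real body, so everything past them is an assumption:

```python
import logging
import torch

logger = logging.getLogger(__name__)

def reinitialize_model_parameters_sketch(model: torch.nn.Module) -> None:
    """Sketch: reinitialize every module that provides `reset_parameters()`."""
    for module in model.modules():
        # Skip modules without any parameters or buffers of their own.
        if not list(module.parameters(recurse=False)) and not list(module.buffers(recurse=False)):
            continue
        if hasattr(module, "reset_parameters"):
            module.reset_parameters()  # may also reset buffers, e.g., batch norm stats
        else:
            logger.warning(f"Module of type {type(module)} has no `reset_parameters()`.")
```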
```diff
@@ -156,3 +156,92 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int]:
         valid_classes: A set of integers of valid classes.
     """
     return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]
+
+
+class ConcatRandomSampler(Sampler[List[int]]):
+    """Sampler for sampling batches from ConcatDatasets.
+
+    Each sampled batch is composed of batches from different BatchSamplers with the specified
+    batch sizes and ranges.
+
+    To clarify the behavior, we provide a little example.
+    ``dataset_lengths = [5, 2]``
+    ``batch_sizes = [3, 1]``
+
+    With this setting, we have a set of indices A={0..4} and B={5,6} for the two datasets.
+    The total batch size will be exactly 4. The first three elements in that batch are
+    elements of A, the last an element of B.
+    An example batch could be ``[3, 1, 0, 6]``.
+
+    Since we always provide a batch size of exactly ``sum(batch_sizes)``, we drop the last
+    batch.
+
+    Args:
+        dataset_lengths: The lengths of the different datasets.
+        batch_sizes: Batch sizes used for specific datasets.
+        complete_dataset_iteration: Provide an index to indicate over which dataset to fully
+            iterate. By default, stops whenever iteration is complete for any dataset.
+        generator: Generator used in sampling.
+        sampler: Lightning automatically passes a DistributedSamplerWrapper. Only used as an
+            indicator that we are in the distributed case.
+    """
+
+    def __init__(
+        self,
+        dataset_lengths: List[int],
+        batch_sizes: List[int],
+        complete_dataset_iteration: Optional[int] = None,
+        generator: Any = None,
+        sampler: Sampler = None,
+    ) -> None:
+        self.batch_sizes = batch_sizes
+        self.complete_dataset_iteration = complete_dataset_iteration
+        self.subset_samplers = []
+        data_start_idx = 0
+        num_batches = []
+        # In the distributed case, each replica samples only its shard of each dataset.
+        rank = 0 if sampler is None else sampler.rank
+        num_replicas = 1 if sampler is None else sampler.num_replicas
+        for dataset_length, batch_size in zip(dataset_lengths, batch_sizes):
+            data_end_idx = data_start_idx + dataset_length
+            start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
+            end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
+            subset_sampler = BatchSampler(
+                SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
+                batch_size,
+                True,
+            )
+            self.subset_samplers.append(subset_sampler)
+            num_batches.append((end_idx - start_idx) // batch_size)
+            data_start_idx = data_end_idx
+        self.length = (
+            min(num_batches)
+            if self.complete_dataset_iteration is None
+            else num_batches[self.complete_dataset_iteration]
+        )
```

Review thread on the `SubsetRandomSampler` line:

> **Review:** Why?
>
> **Reply:** BatchSampler creates batches, SubsetRandomSampler creates random ints from the provided list (List[int] vs int).
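The reply is easy to verify in isolation; a quick runnable check of the two samplers' contracts:

```python
from torch.utils.data import BatchSampler, SubsetRandomSampler

# SubsetRandomSampler yields individual ints from the provided list, in random order.
print(list(SubsetRandomSampler([10, 11, 12, 13, 14])))  # e.g., [12, 10, 14, 11, 13]

# BatchSampler groups them into List[int] batches; drop_last=True discards the
# incomplete final batch, which is why the batch count uses integer division.
print(list(BatchSampler(SubsetRandomSampler([10, 11, 12, 13, 14]), 2, True)))
# e.g., [[13, 10], [12, 14]]
```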
The hunk continues with `__iter__` and `__len__`:

```diff
+
+    def __iter__(self) -> Iterator[List[int]]:
+        """Creates a batch with groups of indices where each group corresponds to one dataset."""
+        if self.complete_dataset_iteration is None:
+            # Default case is iterating once over the shortest iterator. Works nicely with zip.
+            for samples in zip(*self.subset_samplers):
+                yield [j for i in samples for j in i]
+        else:
+            # Iterating over a specific iterator requires dealing with the length of other
+            # iterators.
+            iterators = [iter(sampler) for sampler in self.subset_samplers]
+            for s in iterators[self.complete_dataset_iteration]:
+                samples = []
+                for i, iterator in enumerate(iterators):
+                    if i != self.complete_dataset_iteration:
+                        try:
+                            samples += next(iterator)
+                        except StopIteration:
+                            # Samplers for the other datasets are restarted when exhausted.
+                            iterators[i] = iter(self.subset_samplers[i])
+                            samples += next(iterators[i])
+                    else:
+                        samples += s
+                yield samples
+
+    def __len__(self):
+        return self.length
```

Review thread on `__iter__`:

> **Review:** Can you add comments about the exact logic?

Review thread on the iterator loop:

> **Review:** Is this optimized? Nested for-loops for each batch seems like a lot.
>
> **Reply:** There is no nested loop for each batch. It is a single loop over each iterator. In case 1 this is hidden within …
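Putting it together, a minimal usage sketch; the `ConcatDataset` setup mirrors the docstring example, while how Offline-ER wires this up internally is not shown in this diff:

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from renate.utils.pytorch import ConcatRandomSampler

# Dataset A has 5 elements (values 0..4), dataset B has 2 (values 5, 6), so
# values coincide with global indices and the output is easy to read.
dataset = ConcatDataset([TensorDataset(torch.arange(5)), TensorDataset(torch.arange(5, 7))])
sampler = ConcatRandomSampler(dataset_lengths=[5, 2], batch_sizes=[3, 1])
loader = DataLoader(dataset, batch_sampler=sampler)

for (batch,) in loader:
    # Each batch holds 3 random elements of A followed by 1 element of B,
    # e.g., tensor([3, 1, 0, 6]).
    print(batch)
```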
In the corresponding test module:

```diff
@@ -4,13 +4,14 @@
 import pytest
 import torch
 import torchvision
-from torch.utils.data import TensorDataset
+from torch.utils.data import Sampler, TensorDataset

 from renate.benchmark.datasets.vision_datasets import TorchVisionDataModule
 from renate.benchmark.scenarios import ClassIncrementalScenario
 from renate.memory.buffer import ReservoirBuffer
 from renate.utils import pytorch
 from renate.utils.pytorch import (
+    ConcatRandomSampler,
     cat_nested_tensors,
     complementary_indices,
     get_length_nested_tensors,
```

Review thread on the imports:

> **Review:** Is it possible to add a …
>
> **Reply:** I've added a unit test for the distributed case instead.
```diff
@@ -150,3 +151,40 @@ def test_unique_classes(tmpdir, test_dataset):
     buffer.update(ds, metadata)
     predicted_unique = unique_classes(buffer)
     assert predicted_unique == set(list(range(10)))
+
+
+@pytest.mark.parametrize(
+    "complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]]
+)
+def test_concat_random_sampler(complete_dataset_iteration, expected_batches):
+    sampler = ConcatRandomSampler(
+        dataset_lengths=[15, 5, 20],
+        batch_sizes=[2, 1, 8],
+        complete_dataset_iteration=complete_dataset_iteration,
+    )
+    assert len(sampler) == expected_batches
+    num_batches = 0
+    for sample in sampler:
+        assert all([s < 15 for s in sample[:2]])
+        assert all([15 <= s < 20 for s in sample[2:3]])
+        assert all([20 <= s < 40 for s in sample[3:]])
+        num_batches += 1
+    assert num_batches == expected_batches
```

Review thread on the parametrization:

> **Review:** For …
>
> **Reply:** Yes. It is identical to the [2, 2] case.
>
> **Review:** So a drop_last is implicit?
>
> **Reply:** Yes. Improved doc.
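The parametrized `expected_batches` values follow directly from the per-dataset batch counts, confirming the implicit `drop_last` discussed above:

```python
assert 15 // 2 == 7       # complete_dataset_iteration=0
assert 5 // 1 == 5        # complete_dataset_iteration=1
assert 20 // 8 == 2       # complete_dataset_iteration=2 (drop_last is implicit)
assert min(7, 5, 2) == 2  # complete_dataset_iteration=None stops at the shortest sampler
```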
```diff
+
+
+def test_concat_random_sampler_distributed():
+    """Tests behavior in case of distributed computing."""
+    mock_sampler = Sampler(None)
+    mock_sampler.rank = 1
+    mock_sampler.num_replicas = 2
+    expected_batches = 2
+    sampler = ConcatRandomSampler(
+        dataset_lengths=[16, 10], batch_sizes=[2, 2], sampler=mock_sampler
+    )
+    assert len(sampler) == expected_batches
+    num_batches = 0
+    for sample in sampler:
+        assert all([7 < s < 16 for s in sample[:2]])
+        assert all([21 <= s < 26 for s in sample[2:]])
+        num_batches += 1
+    assert num_batches == expected_batches
```
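The expected index ranges and batch count follow from the sharding arithmetic in `__init__` for rank 1 of 2 replicas:

```python
# The first dataset occupies global indices 0..15, the second 16..25.
start_0 = 0 + round(16 / 2 * 1)     # 8
end_0 = 0 + round(16 / 2 * 2)       # 16 -> shard {8, ..., 15}, hence 7 < s < 16
start_1 = 16 + round(10 / 2 * 1)    # 21
end_1 = 16 + round(10 / 2 * 2)      # 26 -> shard {21, ..., 25}, hence 21 <= s < 26
assert (end_0 - start_0) // 2 == 4  # 4 batches of size 2 from the first shard
assert (end_1 - start_1) // 2 == 2  # 5 indices // 2 -> 2 batches (drop_last)
assert min(4, 2) == 2               # len(sampler) == 2
```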
Final review thread:

> **Review:** Possibly rename?
>
> **Reply:** Suggestions?