Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Offline-ER for NestedTensors #336

Merged
merged 2 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions src/renate/updaters/experimental/offline_er.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from renate.types import NestedTensors
from renate.updaters.learner import ReplayLearner
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.pytorch import move_tensors_to_device
from renate.utils.pytorch import (
cat_nested_tensors,
get_length_nested_tensors,
move_tensors_to_device,
)


class OfflineExperienceReplayLearner(ReplayLearner):
Expand Down Expand Up @@ -96,9 +100,7 @@ def on_model_update_end(self) -> None:
self._num_points_previous_tasks += self._num_points_current_task
self._num_points_current_task = -1

def training_step(
self, batch: Dict[str, Tuple[NestedTensors, torch.Tensor]], batch_idx: int
) -> STEP_OUTPUT:
def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int) -> STEP_OUTPUT:
"""PyTorch Lightning function to return the training loss."""
if self._loss_weight_new_data is None:
alpha = self._num_points_current_task / (
Expand All @@ -107,13 +109,13 @@ def training_step(
else:
alpha = self._loss_weight_new_data
inputs, targets = batch["current_task"]
device = inputs.device
batch_size_current = inputs.shape[0]
device = next(self.parameters()).device
batch_size_current = get_length_nested_tensors(inputs)
batch_size_mem = 0
if "memory" in batch:
(inputs_mem, targets_mem), _ = batch["memory"]
batch_size_mem = inputs_mem.shape[0]
inputs = torch.cat((inputs, inputs_mem), 0)
batch_size_mem = get_length_nested_tensors(inputs_mem)
inputs = cat_nested_tensors((inputs, inputs_mem), 0)
targets = torch.cat((targets, targets_mem), 0)
outputs = self(inputs)
loss = self._loss_fn(outputs, targets)
Expand Down
39 changes: 38 additions & 1 deletion src/renate/utils/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
import math
from typing import List, Optional
from typing import List, Optional, Tuple, Union

import torch
from torch.utils.data import Dataset, random_split
Expand Down Expand Up @@ -87,3 +87,40 @@ def move_tensors_to_device(tensors: NestedTensors, device: torch.device) -> Nest
"Expected `tensors` to be a nested structure of tensors, tuples, list and dict; "
f"discovered {type(tensors)}."
)


def get_length_nested_tensors(batch: NestedTensors) -> int:
    """Return the size of the first dimension of a ``NestedTensors`` structure.

    Only the first leaf tensor is inspected: for a tuple or list this is the first element,
    for a dictionary it is the value of the first key in iteration order. Nested containers
    are followed recursively. All other tensors in the structure are assumed, but not
    verified, to share the same first dimension.

    Args:
        batch: A tensor, or a (possibly nested) tuple, list or dictionary of tensors.

    Returns:
        The size of the first dimension of the first tensor found.

    Raises:
        TypeError: If a non-tensor, non-container object is encountered.
        ValueError: If an empty container is encountered.
    """
    if isinstance(batch, torch.Tensor):
        return batch.shape[0]
    if isinstance(batch, (tuple, list, dict)):
        if not batch:
            raise ValueError("Cannot determine the length of an empty nested structure.")
        # Descend into the first element only; the remaining ones are assumed to match.
        first = next(iter(batch.values())) if isinstance(batch, dict) else batch[0]
        return get_length_nested_tensors(first)
    raise TypeError(
        "Expected `batch` to be a nested structure of tensors, tuples, list and dict; "
        f"discovered {type(batch)}."
    )


def cat_nested_tensors(
    nested_tensors: Union[Tuple[NestedTensors], List[NestedTensors]], axis: int = 0
) -> NestedTensors:
    """Concatenates a sequence of identically structured ``NestedTensors``.

    Equivalent of PyTorch's ``cat`` function for ``NestedTensors``. The structure of the result
    mirrors the structure of the first element: tensors are concatenated directly, tuples and
    lists are concatenated position by position, and dictionaries key by key (using the keys
    of the first dictionary).

    Args:
        nested_tensors: Tensors to be concatenated. All elements must share the same structure.
        axis: Concatenation axis.

    Returns:
        A ``NestedTensors`` object with the same structure as each element of ``nested_tensors``.

    Raises:
        ValueError: If ``nested_tensors`` is empty, or if tuples/lists of different lengths are
            passed.
        TypeError: If the elements are not tensors, tuples, lists or dictionaries.
        RuntimeError: Raised by ``torch.cat`` if tensor shapes are incompatible.
    """
    if len(nested_tensors) == 0:
        raise ValueError("Expected at least one element in `nested_tensors`.")
    first = nested_tensors[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(nested_tensors, axis)
    if isinstance(first, (tuple, list)):
        # `zip` would silently drop trailing entries if the lengths differed.
        if any(len(nested_tensor) != len(first) for nested_tensor in nested_tensors):
            raise ValueError("All elements of `nested_tensors` must have the same length.")
        merged = [cat_nested_tensors(group, axis) for group in zip(*nested_tensors)]
        return tuple(merged) if isinstance(first, tuple) else merged
    if isinstance(first, dict):
        return {
            key: cat_nested_tensors([nested_tensor[key] for nested_tensor in nested_tensors], axis)
            for key in first
        }
    raise TypeError(
        "Expected `nested_tensors` to contain nested structures of tensors, tuples, list and "
        f"dict; discovered {type(first)}."
    )
43 changes: 42 additions & 1 deletion test/renate/utils/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.utils.data import TensorDataset

from renate.utils import pytorch
from renate.utils.pytorch import randomly_split_data
from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors, randomly_split_data


@pytest.mark.parametrize("model", [torchvision.models.resnet18(pretrained=True)])
Expand Down Expand Up @@ -61,3 +61,44 @@ def test_random_splitting_sample_split_with_same_random_seed():
for i in range(5):
assert torch.equal(d_1_split_1[i][0], d_2_split_1[i][0])
assert torch.equal(d_1_split_2[i][0], d_2_split_2[i][0])


def test_get_length_nested_tensors():
    """A tensor, a tuple of tensors and a dict of tensors all report the same length."""
    expected_batch_size = 10
    leaf = torch.zeros(expected_batch_size)
    structures = (
        leaf,
        (leaf, leaf),
        {"k1": leaf, "k2": leaf},
    )
    for structure in structures:
        assert get_length_nested_tensors(structure) == expected_batch_size


def test_cat_nested_tensors():
    """Concatenation works for plain tensors, tuples of tensors and dicts of tensors."""
    tensor_dim = 2
    first_dim_ones = 8
    zeros = torch.zeros((2, tensor_dim))
    ones = torch.ones((first_dim_ones, tensor_dim))

    # Plain tensors: 2 rows of zeros followed by 8 rows of ones.
    result = cat_nested_tensors((zeros, ones))
    assert get_length_nested_tensors(result) == 10
    # The mean is a float; compare with a tolerance rather than exact equality.
    assert torch.isclose(result.mean(), torch.tensor(0.8))

    # Tuples are concatenated position by position.
    tuple_tensor = (zeros, ones)
    result = cat_nested_tensors((tuple_tensor, tuple_tensor))
    assert get_length_nested_tensors(result) == 4
    assert result[0].sum() == 0
    assert result[1].sum() == 2 * first_dim_ones * tensor_dim

    # Dictionaries are concatenated key by key.
    dict_tensor = {"zeros": zeros, "ones": ones}
    result = cat_nested_tensors((dict_tensor, dict_tensor))
    assert get_length_nested_tensors(result) == 4
    assert result["zeros"].sum() == 0
    assert result["ones"].sum() == 2 * first_dim_ones * tensor_dim


def test_cat_nested_tensors_wrong_shape():
    """Mismatching trailing dimensions raise for tensors, tuples and dicts alike."""
    expected_error = r"Sizes of tensors must match except in dimension 0.*"
    tensor1 = torch.zeros((2, 2))
    tensor2 = torch.zeros((2, 3))
    mismatched_inputs = [
        (tensor1, tensor2),
        ((tensor1, tensor1), (tensor1, tensor2)),
        ({"k1": tensor1, "k2": tensor1}, {"k1": tensor1, "k2": tensor2}),
    ]
    for nested_tensors in mismatched_inputs:
        with pytest.raises(RuntimeError, match=expected_error):
            cat_nested_tensors(nested_tensors)