#189 Add unpatch for dropout and consistent dropout #194

Merged: 5 commits merged on Mar 21, 2022

Changes from all commits
22 changes: 22 additions & 0 deletions baal/bayesian/common.py
@@ -0,0 +1,22 @@
from typing import Callable
from torch import nn


def replace_layers_in_module(module: nn.Module, mapping_fn: Callable) -> bool:
    """
    Recursively iterate over the children of a module and replace them according to `mapping_fn`.

    Returns:
        True if a layer has been changed.
    """
    changed = False
    for name, child in module.named_children():
        new_module = mapping_fn(child)

        if new_module is not None:
            changed = True
            module.add_module(name, new_module)

        # recursively apply to child
        changed |= replace_layers_in_module(child, mapping_fn)
    return changed
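
For context, a minimal usage sketch of the new helper (not part of this diff): the mapping function receives each child module and returns either a replacement module or None to leave the child untouched. The toy model and the ReLU-to-GELU swap below are illustrative assumptions, not code from this PR.

    from torch import nn

    from baal.bayesian.common import replace_layers_in_module

    def relu_to_gelu(module: nn.Module):
        # Return a replacement only for the layer types we care about; None keeps the child as is.
        return nn.GELU() if isinstance(module, nn.ReLU) else None

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
    changed = replace_layers_in_module(model, relu_to_gelu)  # modifies `model` in place
    assert changed
    assert not any(isinstance(m, nn.ReLU) for m in model.modules())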
62 changes: 42 additions & 20 deletions baal/bayesian/consistent_dropout.py
@@ -7,6 +7,8 @@
from torch.nn import functional as F
from torch.nn.modules.dropout import _DropoutNd

from baal.bayesian.common import replace_layers_in_module


class ConsistentDropout(_DropoutNd):
"""
Expand Down Expand Up @@ -115,32 +117,50 @@ def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Modu
"""
if not inplace:
module = copy.deepcopy(module)
changed = _patch_dropout_layers(module)
changed = replace_layers_in_module(module, _consistent_dropout_mapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module


def _patch_dropout_layers(module: torch.nn.Module) -> bool:
"""
Recursively iterate over the children of a module and replace them if
they are a dropout layer. This function operates in-place.
def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
"""Replace ConsistentDropout layers in a model with Dropout layers.

Args:
module (torch.nn.Module):
The module in which you would like to replace dropout layers.
inplace (bool, optional):
Whether to modify the module in place or return a copy of the module.

Returns:
torch.nn.Module
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
changed = False
for name, child in module.named_children():
new_module: Optional[nn.Module] = None
if isinstance(child, torch.nn.Dropout):
new_module = ConsistentDropout(p=child.p)
elif isinstance(child, torch.nn.Dropout2d):
new_module = ConsistentDropout2d(p=child.p)
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _consistent_dropout_unmapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module


if new_module is not None:
changed = True
module.add_module(name, new_module)
def _consistent_dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
if isinstance(module, torch.nn.Dropout):
new_module = ConsistentDropout(p=module.p)
elif isinstance(module, torch.nn.Dropout2d):
new_module = ConsistentDropout2d(p=module.p)
return new_module

# recursively apply to child
changed |= _patch_dropout_layers(child)
return changed

def _consistent_dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
if isinstance(module, ConsistentDropout):
new_module = torch.nn.Dropout(p=module.p)
elif isinstance(module, ConsistentDropout2d):
new_module = torch.nn.Dropout2d(p=module.p)
return new_module


class MCConsistentDropoutModule(torch.nn.Module):
@@ -152,8 +172,10 @@ def __init__(self, module: torch.nn.Module):
                A fully specified neural network.
        """
        super().__init__()
        self.parent_module = module
        _patch_dropout_layers(self.parent_module)
        self.parent_module = patch_module(module)

    def forward(self, *args, **kwargs):
        return self.parent_module.forward(*args, **kwargs)

    def unpatch(self) -> torch.nn.Module:
        return unpatch_module(self.parent_module)
Collaborator:

This will make sure that self.parent_module is always patched, but at every step the user can ask for an unpatched version, right? It's very useful; I just want to make sure that if you call unpatch, the internal module remains patched.

Member (Author):

Yes.
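
For illustration, a hedged sketch (not part of this diff) of the wrapper workflow discussed above; the toy two-layer model is only an assumption for the example:

    from torch import nn

    from baal.bayesian.consistent_dropout import ConsistentDropout, MCConsistentDropoutModule

    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
    wrapper = MCConsistentDropoutModule(model)

    # The wrapped module is patched: Dropout layers are now ConsistentDropout.
    assert any(isinstance(m, ConsistentDropout) for m in wrapper.parent_module.modules())

    # unpatch() hands back a module with plain torch.nn.Dropout layers again.
    plain = wrapper.unpatch()
    assert not any(isinstance(m, ConsistentDropout) for m in plain.modules())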

66 changes: 41 additions & 25 deletions baal/bayesian/dropout.py
@@ -7,6 +7,8 @@
from torch.nn import functional as F
from torch.nn.modules.dropout import _DropoutNd

from baal.bayesian.common import replace_layers_in_module


class Dropout(_DropoutNd):
r"""Randomly zeroes some of the elements of the input
Expand Down Expand Up @@ -85,53 +87,65 @@ def forward(self, input):


def patch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
"""Replace dropout layers in a model with MC Dropout layers.
"""Replace dropout layers in a model with MCDropout layers.

Args:
module (torch.nn.Module):
The module in which you would like to replace dropout layers.
inplace (bool, optional):
Whether to modify the module in place or return a copy of the module.

Raises:
UserWarning if no layer is modified.

Returns:
torch.nn.Module
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
if not inplace:
module = copy.deepcopy(module)
changed = _patch_dropout_layers(module)
changed = replace_layers_in_module(module, _dropout_mapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module


def _patch_dropout_layers(module: torch.nn.Module) -> bool:
"""
Recursively iterate over the children of a module and replace them if
they are a dropout layer. This function operates in-place.
def unpatch_module(module: torch.nn.Module, inplace: bool = True) -> torch.nn.Module:
"""Replace MCDropout layers in a model with Dropout layers.

Args:
module (torch.nn.Module):
The module in which you would like to replace dropout layers.
inplace (bool, optional):
Whether to modify the module in place or return a copy of the module.

Returns:
Flag indicating if a layer was modified.
torch.nn.Module
The modified module, which is either the same object as you passed in
(if inplace = True) or a copy of that object.
"""
changed = False
for name, child in module.named_children():
new_module: Optional[nn.Module] = None
if isinstance(child, torch.nn.Dropout):
new_module = Dropout(p=child.p, inplace=child.inplace)
elif isinstance(child, torch.nn.Dropout2d):
new_module = Dropout2d(p=child.p, inplace=child.inplace)
if not inplace:
module = copy.deepcopy(module)
changed = replace_layers_in_module(module, _dropout_unmapping_fn)
if not changed:
warnings.warn("No layer was modified by patch_module!", UserWarning)
return module


if new_module is not None:
changed = True
module.add_module(name, new_module)
def _dropout_mapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
if isinstance(module, torch.nn.Dropout):
new_module = Dropout(p=module.p, inplace=module.inplace)
elif isinstance(module, torch.nn.Dropout2d):
new_module = Dropout2d(p=module.p, inplace=module.inplace)
return new_module

# recursively apply to child
changed |= _patch_dropout_layers(child)
return changed

def _dropout_unmapping_fn(module: torch.nn.Module) -> Optional[nn.Module]:
new_module: Optional[nn.Module] = None
if isinstance(module, Dropout):
new_module = torch.nn.Dropout(p=module.p, inplace=module.inplace)
elif isinstance(module, Dropout2d):
new_module = torch.nn.Dropout2d(p=module.p, inplace=module.inplace)
return new_module


class MCDropoutModule(torch.nn.Module):
@@ -143,8 +157,10 @@ def __init__(self, module: torch.nn.Module):
                A fully specified neural network.
        """
        super().__init__()
        self.parent_module = module
        _patch_dropout_layers(self.parent_module)
        self.parent_module = patch_module(module)

    def forward(self, *args, **kwargs):
        return self.parent_module(*args, **kwargs)

    def unpatch(self) -> torch.nn.Module:
        return unpatch_module(self.parent_module)
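
For illustration, a hedged sketch (not part of this diff) of the function-level API after this change: patch_module swaps torch.nn.Dropout layers for baal's MC Dropout layers, and the new unpatch_module reverses it. The toy model is an assumption used only for the example.

    from torch import nn

    from baal.bayesian import dropout

    model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))

    patched = dropout.patch_module(model, inplace=False)  # leaves `model` untouched
    assert any(isinstance(m, dropout.Dropout) for m in patched.modules())

    restored = dropout.unpatch_module(patched, inplace=False)
    assert not any(isinstance(m, dropout.Dropout) for m in restored.modules())
    assert any(isinstance(m, nn.Dropout) for m in restored.modules())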
2 changes: 1 addition & 1 deletion baal/utils/plot_utils.py
@@ -75,7 +75,7 @@ def make_animation_from_data(
    return frames


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
    from sklearn.datasets import make_classification
    import imageio

14 changes: 14 additions & 0 deletions tests/active/active_loop_test.py
@@ -1,5 +1,6 @@
import os
import pickle
import warnings

import numpy as np
import pytest
@@ -126,5 +127,18 @@ def test_file_saving(tmpdir):
    assert (data['dataset']['labelled'] != dataset.labelled).sum() == 10


def test_deprecation():
    heur = heuristics.BALD()
    ds = MyDataset()
    dataset = ActiveLearningDataset(ds, make_unlabelled=lambda x: -1)
    with warnings.catch_warnings(record=True) as w:
        active_loop = ActiveLearningLoop(dataset,
                                         get_probs_iter,
                                         heur,
                                         ndata_to_label=10,
                                         dummy_param=1)
    assert issubclass(w[-1].category, DeprecationWarning)
    assert "ndata_to_label" in str(w[-1].message)

if __name__ == '__main__':
    pytest.main()
64 changes: 64 additions & 0 deletions tests/bayesian/common_test.py
@@ -0,0 +1,64 @@
import pytest
from torch import nn

from baal.bayesian.common import replace_layers_in_module


@pytest.fixture
def a_model_deep():
    return nn.Sequential(
        nn.Linear(32, 32),
        nn.Sequential(
            nn.Linear(32, 3),
            nn.ReLU(),
            nn.Linear(10, 3),
            nn.ReLU(),
            nn.Linear(3, 3)
        ))


@pytest.fixture
def a_model():
    return nn.Sequential(
        nn.Linear(32, 3),
        nn.ReLU(),
        nn.Linear(10, 3),
        nn.ReLU(),
        nn.Linear(3, 3)
    )


def test_replace_layers_in_module_swap_all_relu(a_model):
    mapping = lambda mod: None if not isinstance(mod, nn.ReLU) else nn.Identity()
    changed = replace_layers_in_module(a_model, mapping)
    assert changed
    assert not any(isinstance(m, nn.ReLU) for m in a_model.modules())
    assert any(isinstance(m, nn.Identity) for m in a_model.modules())


def test_replace_layers_in_module_swap_all_relu_deep(a_model_deep):
    mapping = lambda mod: None if not isinstance(mod, nn.ReLU) else nn.Identity()
    changed = replace_layers_in_module(a_model_deep, mapping)
    assert changed
    assert not any(isinstance(m, nn.ReLU) for m in a_model_deep.modules())
    assert any(isinstance(m, nn.Identity) for m in a_model_deep.modules())


def test_replace_layers_in_module_swap_no_relu_deep(a_model_deep):
    mapping = lambda mod: None if not isinstance(mod, nn.ReLU6) else nn.Identity()
    changed = replace_layers_in_module(a_model_deep, mapping)
    assert not changed
    assert any(isinstance(m, nn.ReLU) for m in a_model_deep.modules())
    assert not any(isinstance(m, nn.Identity) for m in a_model_deep.modules())

def test_replace_layers_in_module_swap_no_relu(a_model):
    mapping = lambda mod: None if not isinstance(mod, nn.ReLU6) else nn.Identity()
    changed = replace_layers_in_module(a_model, mapping)
    assert not changed
    assert any(isinstance(m, nn.ReLU) for m in a_model.modules())
    assert not any(isinstance(m, nn.Identity) for m in a_model.modules())



if __name__ == '__main__':
    pytest.main()
10 changes: 10 additions & 0 deletions tests/bayesian/consistent_dropout_test.py
@@ -96,6 +96,16 @@ def test_module_class_replaces_dropout_layers(a_model_with_dropout):
            for _ in range(10)
        )

    # Check that unpatch works
    module = test_mc_module.unpatch()
    module.eval()
    with torch.no_grad():
        assert all(
            torch.allclose(module(dummy_input), module(dummy_input))
            for _ in range(10)
        )
    assert not any(isinstance(mod, baal.bayesian.consistent_dropout.ConsistentDropout) for mod in module.modules())

@pytest.mark.parametrize("inplace", (True, False))
def test_patch_module_raise_warnings(inplace):

18 changes: 16 additions & 2 deletions tests/bayesian/dropconnect_test.py
@@ -3,7 +3,7 @@
import pytest
import torch

from baal.bayesian.weight_drop import patch_module, WeightDropLinear
from baal.bayesian.weight_drop import patch_module, WeightDropLinear, MCDropoutConnectModule


class SimpleModel(torch.nn.Module):
@@ -70,7 +70,21 @@ def test_patch_module_replaces_all_dropout_layers(inplace):
    # objects should be the same if inplace is True and not otherwise:
    assert (mc_test_module is test_module) == inplace
    assert not any(
        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
    )
    assert any(
        isinstance(module, WeightDropLinear)
        for module in mc_test_module.modules()
    )


def test_mcdropconnect_replaces_all_dropout_layers_module():
    test_module = SimpleModel()

    mc_test_module = MCDropoutConnectModule(test_module, layers=['Conv2d', 'Linear', 'LSTM', 'GRU'])

    assert not any(
        module.p != 0 for module in mc_test_module.modules() if isinstance(module, torch.nn.Dropout)
    )
    assert any(
        isinstance(module, WeightDropLinear)
6 changes: 6 additions & 0 deletions tests/bayesian/dropout_test.py
@@ -106,5 +106,11 @@ def test_module_class_replaces_dropout_layers(a_model_with_dropout):
        )


    # Check that unpatch works
    module = test_mc_module.unpatch()
    assert not any(isinstance(mod, baal.bayesian.dropout.Dropout) for mod in module.modules())



if __name__ == '__main__':
    pytest.main()