
SSL Online evaluator: save checkpoints without DDP wrapper #623

Merged: 12 commits, Jan 6, 2022
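
Background for the change, as a minimal illustration rather than code from this PR: a module wrapped in `DistributedDataParallel` stores the original model under `.module`, so every key in its `state_dict` gains a `module.` prefix, and a checkpoint written from the wrapped object cannot be loaded directly into the bare evaluator. The sketch below uses `torch.nn.DataParallel`, which applies the same prefix but can be constructed without an initialised process group; the file name is made up.

```python
import torch
from torch import nn

# Stand-in for the linear evaluator head used by the SSL online evaluator.
evaluator = nn.Linear(in_features=4, out_features=2)
print(list(evaluator.state_dict().keys()))    # ['weight', 'bias']

# DataParallel (like DistributedDataParallel) registers the original model as
# `self.module`, so every state_dict key gains a "module." prefix.
wrapped = nn.DataParallel(evaluator)
print(list(wrapped.state_dict().keys()))      # ['module.weight', 'module.bias']

# Saving the inner module keeps checkpoint keys independent of whether the
# training run wrapped the evaluator or not.
to_save = wrapped.module if isinstance(wrapped, nn.DataParallel) else wrapped
torch.save(to_save.state_dict(), "evaluator_checkpoint.pt")
```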
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -66,6 +66,7 @@ gets uploaded to AzureML, by skipping all test folders.
- ([#566](https://github.com/microsoft/InnerEye-DeepLearning/pull/566)) Update `hi-ml` dependency to `hi-ml-azure`.
- ([#591](https://github.com/microsoft/InnerEye-DeepLearning/pull/591)) Upgrade PyTorch Lightning to 1.5.0
- ([#572](https://github.com/microsoft/InnerEye-DeepLearning/pull/572)) Updated to new version of hi-ml package
- ([#623](https://github.com/microsoft/InnerEye-DeepLearning/pull/623)) Save checkpoints in `SSLOnlineEvaluator` without the DDP wrapper.
- ([#617](https://github.com/microsoft/InnerEye-DeepLearning/pull/617)) Provide an easier way for LightningContainers to add callbacks.
- ([#596](https://github.com/microsoft/InnerEye-DeepLearning/pull/596)) Add `cudatoolkit=11.1` specification to environment.yml.
- ([#615](https://github.com/microsoft/InnerEye-DeepLearning/pull/615)) Minor changes to checkpoint download from AzureML.
40 changes: 29 additions & 11 deletions InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
@@ -3,6 +3,8 @@
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------

from typing import Any, Dict, List, Optional, OrderedDict, Set, Tuple, Union

import pytorch_lightning as pl
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
@@ -12,7 +14,6 @@
from torch.nn import SyncBatchNorm, functional as F
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from InnerEye.ML.SSL.utils import SSLDataModuleType, add_submodules_to_same_device
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
@@ -49,30 +50,36 @@ def __init__(self,
Accuracy05()] \
if self.num_classes == 2 else [Accuracy05()]
self.class_weights = class_weights
self.evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
self.evaluator_state: Optional[OrderedDict] = None
self.optimizer_state: Optional[OrderedDict] = None

def _wrapped_evaluator(self) -> torch.nn.Module:
"""
Returns the model inside the DDP wrapper if the evaluator has been wrapped, or the evaluator itself otherwise.
"""
if isinstance(self.evaluator, DistributedDataParallel):
return self.evaluator.module
else:
return self.evaluator

def on_save_checkpoint(self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
checkpoint: Dict[str, Any]) -> Dict[str, Any]:
# Each callback gets its own state dictionary, which is fed back in during load.
# When saving the evaluator, use the model inside the DDP wrapper (otherwise the resulting checkpoint
# would depend on whether DDP was used or not).
return {
self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
self.EVALUATOR_STATE_NAME: self.evaluator.state_dict()
self.EVALUATOR_STATE_NAME: self._wrapped_evaluator().state_dict()
}

def on_load_checkpoint(self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
callback_state: Dict[str, Any]) -> None:
self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME])
self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME])
self.optimizer_state = callback_state[self.OPTIMIZER_STATE_NAME]
self.evaluator_state = callback_state[self.EVALUATOR_STATE_NAME]

def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""
@@ -82,6 +89,10 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
"""
for prefix, metrics in [("train", self.train_metrics), ("val", self.val_metrics)]:
add_submodules_to_same_device(pl_module, metrics, prefix=prefix)
self.evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.evaluator.to(pl_module.device)
if hasattr(trainer, "accelerator_connector"):
# This works with Lightning 1.3.8
@@ -98,6 +109,13 @@ def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.Lightning
else:
rank_zero_warn("This type of distributed accelerator is not supported. "
"The online evaluator will not synchronize across GPUs.")
self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
if self.evaluator_state is not None:
self._wrapped_evaluator().load_state_dict(self.evaluator_state)
if self.optimizer_state is not None:
self.optimizer.load_state_dict(self.optimizer_state)

@staticmethod
def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
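Taken together, the changes in this file follow one pattern: save the evaluator state without the DDP wrapper, stash the state when a checkpoint is loaded, and restore it only after the evaluator has been rebuilt (and possibly re-wrapped) at the start of the pre-training routine. A condensed sketch of that pattern is shown below; the callback name, state key, and probe dimensions are illustrative, not taken from the repository.

```python
from typing import Any, Dict, Optional

import pytorch_lightning as pl
import torch
from torch.nn.parallel import DistributedDataParallel


class LinearProbeCallback(pl.Callback):
    """Sketch of the save/restore pattern from the diff above (names are illustrative)."""

    STATE_KEY = "probe_state"

    def __init__(self) -> None:
        self.probe: Optional[torch.nn.Module] = None
        self.probe_state: Optional[Dict[str, torch.Tensor]] = None

    def _unwrapped(self) -> torch.nn.Module:
        # Strip the DDP wrapper, if present, so checkpoint keys stay stable.
        assert self.probe is not None
        if isinstance(self.probe, DistributedDataParallel):
            return self.probe.module
        return self.probe

    def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
                           checkpoint: Dict[str, Any]) -> Dict[str, Any]:
        # Save the unwrapped state so the checkpoint loads with or without DDP.
        return {self.STATE_KEY: self._unwrapped().state_dict()}

    def on_load_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule,
                           callback_state: Dict[str, Any]) -> None:
        # Only stash the state here: the probe has not been (re)built yet.
        self.probe_state = callback_state[self.STATE_KEY]

    def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Build the probe on the right device (DDP wrapping would happen here too),
        # then restore any stashed state into the unwrapped module.
        self.probe = torch.nn.Linear(128, 2).to(pl_module.device)
        if self.probe_state is not None:
            self._unwrapped().load_state_dict(self.probe_state)
```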
25 changes: 13 additions & 12 deletions Tests/SSL/test_ssl_containers.py
@@ -320,9 +320,6 @@ def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> No
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# Ensure that the parameters are really different initially
parameters2_before_training = list(callback2.evaluator.parameters())
assert not torch.allclose(parameters2_before_training[0], parameters1[0])
# Start a second training run with recovery
last_checkpoint = checkpoints.last_model_path
trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir),
@@ -363,19 +360,24 @@ def test_online_evaluator_not_distributed() -> None:
# Test the flag that the internal logic of on_pretrain_routine_start uses
assert hasattr(trainer, "_accelerator_connector")
assert not trainer._accelerator_connector.is_distributed
mock_module = mock.MagicMock(device=torch.device("cpu"))
callback.on_pretrain_routine_start(trainer, mock_module)
cpu = torch.device("cpu")
callback.on_pretrain_routine_start(trainer, mock.MagicMock(device=cpu))
assert isinstance(callback.evaluator, Module)
mock_ddp.assert_not_called()
# Check that the evaluator is on the CPU before any device changes are made
assert list(callback.evaluator.parameters())[0].device == cpu
# Check that the evaluator is really moved to the right device
gpu0 = torch.device("cuda:0")
callback.on_pretrain_routine_start(trainer, mock.MagicMock(device=gpu0))
assert list(callback.evaluator.parameters())[0].device == gpu0


@pytest.mark.gpu
def test_online_evaluator_distributed() -> None:
"""
Check if the online evaluator uses the DDP flag correctly when running distributed.
"""
mock_ddp_result = "mock_ddp_result"
mock_sync_result = "mock_sync_result"
mock_ddp_result = torch.nn.Linear(in_features=10, out_features=1)
mock_sync_result = torch.nn.Linear(in_features=20, out_features=2)
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SyncBatchNorm.convert_sync_batchnorm",
return_value=mock_sync_result) as mock_sync:
with mock.patch("InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.DistributedDataParallel",
@@ -388,16 +390,15 @@ def test_online_evaluator_distributed() -> None:
learning_rate=1e-5)

# Trainer with DDP
device = torch.device("cuda:0")
device = torch.device("cpu")
mock_module = mock.MagicMock(device=device)
trainer = Trainer(accelerator="ddp", gpus=2)
trainer = Trainer(strategy="ddp", num_processes=2)
# Test the two flags that the internal logic of on_pretrain_routine_start uses
assert trainer._accelerator_connector.is_distributed
assert trainer._accelerator_connector.use_ddp
original_evaluator = callback.evaluator
callback.on_pretrain_routine_start(trainer, mock_module)
# Check that SyncBatchNorm has been turned on
mock_sync.assert_called_once_with(original_evaluator)
mock_sync.assert_called_once()
# Check that the evaluator has been turned into a DDP object
# We still need to mock DDP here because the constructor relies on having a process group available
mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])
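The distributed test above mocks `DistributedDataParallel` because its constructor requires an initialised process group. As a generic, self-contained illustration of that mocking technique (not the repository's test code, and using a toy module), the same assertions can be exercised like this:

```python
from unittest import mock

import torch
from torch.nn import SyncBatchNorm

# Toy stand-in for the evaluator head, with a BatchNorm layer to convert.
evaluator = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4))

# Patching DistributedDataParallel lets the wrapping logic run without an
# initialised process group; the mock records how it was called.
with mock.patch("torch.nn.parallel.DistributedDataParallel",
                return_value=torch.nn.Linear(8, 4)) as mock_ddp:
    synced = SyncBatchNorm.convert_sync_batchnorm(evaluator)
    torch.nn.parallel.DistributedDataParallel(synced, device_ids=[torch.device("cpu")])
    mock_ddp.assert_called_once_with(synced, device_ids=[torch.device("cpu")])
```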