
Enabling distributed training for SSL online evaluator #612

Merged
merged 16 commits on Dec 10, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -91,6 +91,7 @@ in inference-only runs when using lightning containers.
- ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model
weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
- ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided.
- ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) Fix issue where the SSL online evaluator was not performing distributed training.

### Removed

58 changes: 33 additions & 25 deletions InnerEye/ML/SSL/lightning_modules/ssl_online_evaluator.py
@@ -9,21 +9,24 @@
import torch
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.models.self_supervised.evaluator import SSLEvaluator
from pytorch_lightning.utilities import rank_zero_warn
from torch import Tensor as T
from health_ml.utils import log_on_epoch
from torch.nn import functional as F
from torch.nn import DataParallel, functional as F
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Metric

from InnerEye.ML.SSL.utils import SSLDataModuleType
from InnerEye.ML.lightning_metrics import Accuracy05, AreaUnderPrecisionRecallCurve, AreaUnderRocCurve
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode
from health_ml.utils import log_on_epoch

BatchType = Union[Dict[SSLDataModuleType, Any], Any]

OPTIMIZER_STATE_NAME = "evaluator_optimizer"
EVALUATOR_STATE_NAME = "evaluator_weights"


class SSLOnlineEvaluatorInnerEye(SSLOnlineEvaluator):
OPTIMIZER_STATE_NAME = "evaluator_optimizer"
EVALUATOR_STATE_NAME = "evaluator_weights"

def __init__(self,
learning_rate: float,
class_weights: Optional[torch.Tensor] = None,
@@ -47,11 +50,11 @@ def __init__(self,
Accuracy05()] \
if self.num_classes == 2 else [Accuracy05()]
self.class_weights = class_weights
self.non_linear_evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.optimizer = torch.optim.Adam(self.non_linear_evaluator.parameters(),
self.evaluator = SSLEvaluator(n_input=self.z_dim,
n_classes=self.num_classes,
p=self.drop_p,
n_hidden=self.hidden_dim)
self.optimizer = torch.optim.Adam(self.evaluator.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay)
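
For reference, the `SSLEvaluator` head instantiated above is essentially a small dropout-regularized MLP that maps SSL embeddings to class logits. A rough functional sketch (illustrative only; the helper name is made up, and the exact layer stack depends on the pinned pl_bolts version):

```python
import torch.nn as nn

# Sketch of an MLP head comparable to pl_bolts' SSLEvaluator. The function
# name is invented for illustration; check the pinned pl_bolts version for
# the exact architecture.
def make_evaluator_head(n_input: int, n_classes: int, n_hidden: int, p: float) -> nn.Module:
    return nn.Sequential(
        nn.Flatten(),
        nn.Dropout(p=p),
        nn.Linear(n_input, n_hidden, bias=False),
        nn.BatchNorm1d(n_hidden),
        nn.ReLU(inplace=True),
        nn.Dropout(p=p),
        nn.Linear(n_hidden, n_classes),
    )
```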

@@ -61,24 +64,33 @@ def on_save_checkpoint(self,
checkpoint: Dict[str, Any]) -> Dict[str, Any]:
# Each callback gets its own state dictionary, which is fed back in when the checkpoint is loaded
return {
OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
EVALUATOR_STATE_NAME: self.non_linear_evaluator.state_dict()
self.OPTIMIZER_STATE_NAME: self.optimizer.state_dict(),
self.EVALUATOR_STATE_NAME: self.evaluator.state_dict()
}

def on_load_checkpoint(self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
callback_state: Dict[str, Any]) -> None:
self.optimizer.load_state_dict(callback_state[OPTIMIZER_STATE_NAME])
self.non_linear_evaluator.load_state_dict(callback_state[EVALUATOR_STATE_NAME])
self.optimizer.load_state_dict(callback_state[self.OPTIMIZER_STATE_NAME])
self.evaluator.load_state_dict(callback_state[self.EVALUATOR_STATE_NAME])
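
With these two hooks in place, Lightning nests the returned dictionary under the callback's class inside the checkpoint file. A sketch of how the stored state can be read back (this mirrors the assertions in Tests/SSL/test_ssl_containers.py below; the checkpoint path is illustrative):

```python
import torch

# Sketch: PyTorch Lightning 1.x keys callback state by callback class.
checkpoint = torch.load("last.ckpt")  # illustrative path
callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
optimizer_state = callback_state[SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME]
evaluator_state = callback_state[SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME]
```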

def on_pretrain_routine_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
"""
Initializes modules and moves metrics and class weights to the module's device
"""
for metric in [*self.train_metrics, *self.val_metrics]:
metric.to(device=pl_module.device) # type: ignore
self.non_linear_evaluator.to(pl_module.device)
self.evaluator.to(pl_module.device)
accelerator = trainer.accelerator_connector
if accelerator.is_distributed:
if accelerator.use_ddp:
self.evaluator = DistributedDataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
elif accelerator.use_dp:
self.evaluator = DataParallel(self.evaluator, device_ids=[pl_module.device]) # type: ignore
else:
rank_zero_warn("This type of distributed accelerator is not supported. "
"The online evaluator will not synchronize across GPUs.")

@staticmethod
def to_device(batch: Any, device: Union[str, torch.device]) -> Tuple[T, T]:
@@ -108,7 +120,7 @@ def shared_step(self, batch: BatchType, pl_module: pl.LightningModule, is_traini
representations = representations.detach()

# Run the linear-head with SSL embeddings.
mlp_preds = self.non_linear_evaluator(representations)
mlp_preds = self.evaluator(representations)
weights = None if self.class_weights is None else self.class_weights.to(device=pl_module.device)
mlp_loss = F.cross_entropy(mlp_preds, y, weight=weights)
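
The class-weighted loss above is plain `F.cross_entropy` with a per-class weight vector; a toy sketch of the call (all values made up for illustration):

```python
import torch
from torch.nn import functional as F

logits = torch.randn(4, 2)            # [batch_size, num_classes] linear-head outputs
targets = torch.tensor([0, 1, 1, 0])  # ground-truth labels
weights = torch.tensor([0.3, 0.7])    # per-class weights, e.g. to counter class imbalance
loss = F.cross_entropy(logits, targets, weight=weights)
```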

@@ -133,15 +145,11 @@ def on_validation_batch_end(self, trainer: pl.Trainer,
ids_linear_head = tuple(batch[SSLDataModuleType.LINEAR_HEAD][0].tolist())
if ids_linear_head not in self.visited_ids:
self.visited_ids.add(ids_linear_head)
# Put the online evaluator into "eval" mode
old_mode = self.non_linear_evaluator.training
self.non_linear_evaluator.eval()
loss = self.shared_step(batch, pl_module, is_training=False)
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
for metric in self.val_metrics:
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)
# Put the online evaluator back into the state (eval or train) that it was before calling this method
self.non_linear_evaluator.train(old_mode)
with set_model_to_eval_mode(self.evaluator):
loss = self.shared_step(batch, pl_module, is_training=False)
log_on_epoch(pl_module, 'ssl_online_evaluator/val/loss', loss)
for metric in self.val_metrics:
log_on_epoch(pl_module, f"ssl_online_evaluator/val/{metric.name}", metric)

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) -> None: # type: ignore
"""
16 changes: 15 additions & 1 deletion InnerEye/ML/utils/layer_util.py
@@ -2,7 +2,8 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from typing import Iterable, Sized, Tuple, Union
from contextlib import contextmanager
from typing import Generator, Iterable, Sized, Tuple, Union

import torch
from torch.nn import init
@@ -90,3 +91,16 @@ def upsample_size(down: int) -> int:
upsample_size(downsampling_factor[1]), # type: ignore
upsample_size(downsampling_factor[2])) # type: ignore
return upsampling_kernel_size


@contextmanager
def set_model_to_eval_mode(model: torch.nn.Module) -> Generator:
"""
Puts the given torch model into eval mode. At the end of the context, resets the training flag to what it
was before the call.
:param model: The model to modify.
"""
old_mode = model.training
model.eval()
yield
model.train(old_mode)
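
A usage sketch for the new context manager (not part of the diff): the previous value of the training flag is restored when the context exits.

```python
import torch
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode

model = torch.nn.Linear(4, 2)  # nn.Module instances start out in training mode
with set_model_to_eval_mode(model):
    assert not model.training  # eval mode while inside the context
assert model.training          # training flag restored on exit
```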
11 changes: 4 additions & 7 deletions InnerEye/ML/visualizers/model_summary.py
@@ -17,6 +17,7 @@
from InnerEye.Common.fixed_paths import DEFAULT_MODEL_SUMMARIES_DIR_PATH
from InnerEye.ML.utils.device_aware_module import DeviceAwareModule
from InnerEye.ML.utils.ml_util import RandomStateSnapshot
from InnerEye.ML.utils.layer_util import set_model_to_eval_mode


@dataclass
@@ -217,15 +218,11 @@ def forward_preserve_state(module: DeviceAwareModule, inputs: List[torch.Tensor]
inputs = [input_tensor.cuda() for input_tensor in inputs]

# collect the current state of the model
is_train = module.training
module_state = RandomStateSnapshot.snapshot_random_state()

# set the model in evaluation mode and perform a forward pass
module.eval()
with torch.no_grad():
output = module.forward(*inputs)
if is_train:
module.train()
with set_model_to_eval_mode(module):
with torch.no_grad():
output = module.forward(*inputs)

# restore the seed for torch and numpy
module_state.restore_random_state()
105 changes: 101 additions & 4 deletions Tests/SSL/test_ssl_containers.py
@@ -12,22 +12,27 @@
import pytest
import torch
from pl_bolts.models.self_supervised.resnets import ResNet
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.nn import Module
from torch.nn.parallel import DistributedDataParallel
from torch.optim.lr_scheduler import _LRScheduler

from InnerEye.Common import fixed_paths
from InnerEye.Common.common_util import is_windows
from InnerEye.Common.fixed_paths import repository_root_directory
from InnerEye.Common.fixed_paths_for_tests import full_ml_test_data_path
from InnerEye.Common.output_directories import OutputFolderForTests
from InnerEye.ML.SSL.lightning_containers.ssl_container import EncoderName, SSLDatasetName
from InnerEye.ML.SSL.lightning_modules.byol.byol_module import BYOLInnerEye
from InnerEye.ML.SSL.lightning_modules.simclr_module import SimCLRInnerEye
from InnerEye.ML.SSL.lightning_modules.ssl_classifier_module import SSLClassifier
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import EVALUATOR_STATE_NAME, OPTIMIZER_STATE_NAME, \
SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
from InnerEye.ML.runner import Runner
from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
from Tests.ML.utils.test_io_util import write_test_dicom

path_to_test_dataset = full_ml_test_data_path("cxr_test_dataset")
@@ -133,8 +138,8 @@ def test_innereye_ssl_container_cifar10_resnet_simclr() -> None:
assert "callbacks" in checkpoint
assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
assert OPTIMIZER_STATE_NAME in callback_state
assert EVALUATOR_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state

# Now run the actual SSL classifier off the stored checkpoint
args = common_test_args + ["--model=SSLClassifierCIFAR", f"--local_ssl_weights_path={checkpoint_path}"]
@@ -268,3 +273,95 @@ def test_simclr_lr_scheduler() -> None:
assert lr[i] < lr[i + 1], f"Not strictly monotonically increasing at index {i}"
for i in range(highest_lr, len(lr) - 1):
assert lr[i] > lr[i + 1], f"Not strictly monotonically decreasing at index {i}"


def test_online_evaluator_recovery(test_output_dirs: OutputFolderForTests) -> None:
"""
Test checkpoint recovery for the online evaluator in an end-to-end training run.
"""
container = DummyContainerWithModel()
model = container.create_model()
data = container.get_data_module()
checkpoint_folder = test_output_dirs.create_file_or_folder_path("checkpoints")
checkpoint_folder.mkdir(exist_ok=True)
checkpoints = ModelCheckpoint(dirpath=checkpoint_folder,
every_n_val_epochs=1,
save_last=True)
# Create a first callback, that will be used in training.
callback1 = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# To simplify the test setup, do not run any actual training (this would require a complicated dataset with a
# combined loader)
with mock.patch(
"InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye.on_train_batch_end",
return_value=None) as mock_train:
with mock.patch(
"InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator.SSLOnlineEvaluatorInnerEye"
".on_validation_batch_end",
return_value=None):
trainer = Trainer(default_root_dir=str(test_output_dirs.root_dir),
callbacks=[checkpoints, callback1],
max_epochs=10)
trainer.fit(model, datamodule=data)
# Check that the callback was actually used
mock_train.assert_called()
# Now read out the parameters of the callback.
# We will then run a second training job, with a new callback object, that will be initialized randomly,
# and should have different parameters initially. After checkpoint recovery, it should have exactly the
# same parameters as the first callback.
parameters1 = list(callback1.evaluator.parameters())
callback2 = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
# Ensure that the parameters are really different initially
parameters2_before_training = list(callback2.evaluator.parameters())
assert not torch.allclose(parameters2_before_training[0], parameters1[0])
# Start a second training run with recovery
last_checkpoint = checkpoints.last_model_path
trainer2 = Trainer(default_root_dir=str(test_output_dirs.root_dir),
callbacks=[callback2],
max_epochs=20,
resume_from_checkpoint=last_checkpoint)
trainer2.fit(model, datamodule=data)
# Read the parameters and check if they are the same as what was stored in the first callback.
parameters2_after_training = list(callback2.evaluator.parameters())
assert torch.allclose(parameters2_after_training[0], parameters1[0])
Review comment (Contributor): Could we also check if the optimizer state is restored correctly?


# Somewhat redundant by now, but we can also check that the checkpoint file really contained the optimizer and weights
checkpoint = torch.load(last_checkpoint)
assert "callbacks" in checkpoint
assert SSLOnlineEvaluatorInnerEye in checkpoint["callbacks"]
callback_state = checkpoint["callbacks"][SSLOnlineEvaluatorInnerEye]
assert SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME in callback_state
assert SSLOnlineEvaluatorInnerEye.EVALUATOR_STATE_NAME in callback_state
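
Following up on the review comment above, a sketch of the additional assertion the reviewer asks for (illustrative, not part of the merged PR); `callback_state` and `callback2` are as defined in the test:

```python
# Sketch only: check that the optimizer hyperparameters stored in the
# checkpoint match what callback2 holds after recovery.
saved_optimizer_state = callback_state[SSLOnlineEvaluatorInnerEye.OPTIMIZER_STATE_NAME]
restored_optimizer_state = callback2.optimizer.state_dict()
assert saved_optimizer_state["param_groups"] == restored_optimizer_state["param_groups"]
```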


@pytest.mark.gpu
def test_online_evaluator_distributed() -> None:
"""
A very primitive type of test to check if the online evaluator uses the DDP flag correctly.
"""
callback = SSLOnlineEvaluatorInnerEye(class_weights=None,
z_dim=1,
num_classes=2,
dataset="foo",
drop_p=0.2,
learning_rate=1e-5)
assert isinstance(callback.evaluator, Module)
assert not isinstance(callback.evaluator, DistributedDataParallel)
trainer = Trainer()
mock_module = mock.MagicMock(device=torch.device("cpu"))
callback.on_pretrain_routine_start(trainer, mock_module)
assert isinstance(callback.evaluator, Module)
assert not isinstance(callback.evaluator, DistributedDataParallel)
mock_module = mock.MagicMock(device=torch.device("cuda:0"))
trainer = Trainer(accelerator="ddp", gpus=2)
callback.on_pretrain_routine_start(trainer, mock_module)
assert isinstance(callback.evaluator, DistributedDataParallel)
2 changes: 1 addition & 1 deletion hi-ml
Submodule hi-ml updated 0 files