diff --git a/CHANGELOG.md b/CHANGELOG.md
index ef596be14..b0b088ee4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,7 @@
 in inference-only runs when using lightning containers.
 - ([#558](https://github.com/microsoft/InnerEye-DeepLearning/pull/558)) Fix issue with the CovidModel config where model weights from a finetuning run were incompatible with the model architecture created for non-finetuning runs.
 - ([#604](https://github.com/microsoft/InnerEye-DeepLearning/pull/604)) Fix issue where runs on a VM would download the dataset even when a local dataset is provided.
+- ([#628](https://github.com/microsoft/InnerEye-DeepLearning/pull/628)) SSL SimCLR was using the wrong LR schedule when running on multiple nodes
 - ([#612](https://github.com/microsoft/InnerEye-DeepLearning/pull/612)) SSL online evaluator was not doing distributed training
 
 ### Removed
diff --git a/InnerEye/ML/SSL/lightning_containers/ssl_container.py b/InnerEye/ML/SSL/lightning_containers/ssl_container.py
index edb5412e6..a9d218e69 100644
--- a/InnerEye/ML/SSL/lightning_containers/ssl_container.py
+++ b/InnerEye/ML/SSL/lightning_containers/ssl_container.py
@@ -151,9 +151,10 @@ def create_model(self) -> LightningModule:
             model: LightningModule = SimCLRInnerEye(encoder_name=self.ssl_encoder.value,
                                                     dataset_name=self.ssl_training_dataset_name.value,
                                                     use_7x7_first_conv_in_resnet=use_7x7_first_conv_in_resnet,
-                                                    gpus=self.total_num_gpus,
                                                     num_samples=self.data_module.num_samples,
                                                     batch_size=self.data_module.batch_size,
+                                                    gpus=self.num_gpus_per_node(),
+                                                    num_nodes=self.num_nodes,
                                                     learning_rate=self.l_rate,
                                                     max_epochs=self.num_epochs)
         elif self.ssl_training_type == SSLTrainingType.BYOL:
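For context: SimCLR in LightningBolts derives the length of its cosine LR schedule from a global batch size. A minimal sketch of that computation, assuming the post-update semantics in which `gpus` counts GPUs *per node* and is multiplied by `num_nodes` (the exact LightningBolts code may differ slightly):

```python
# Sketch of how the SimCLR cosine LR schedule length is derived, under the assumed
# post-update LightningBolts semantics: `gpus` means GPUs per node, not total GPUs.
def train_iters_per_epoch(num_samples: int, batch_size: int, gpus: int, num_nodes: int = 1) -> int:
    # Global batch size across all DDP processes; on CPU (gpus == 0) it is just batch_size.
    global_batch_size = num_nodes * gpus * batch_size if gpus > 0 else batch_size
    return num_samples // global_batch_size

# Passing the per-node GPU count together with num_nodes keeps the product equal to the
# true world size. A single pre-multiplied GPU count no longer composes correctly with
# num_nodes under the new semantics, which is what skewed the schedule on multiple nodes.
assert train_iters_per_epoch(num_samples=100, batch_size=10, gpus=1, num_nodes=2) == 5
```

This is why the fix above passes `gpus=self.num_gpus_per_node()` and `num_nodes=self.num_nodes` as two separate arguments instead of a single total.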
diff --git a/Tests/SSL/test_ssl_containers.py b/Tests/SSL/test_ssl_containers.py
index fbd1f456a..111a84f8f 100644
--- a/Tests/SSL/test_ssl_containers.py
+++ b/Tests/SSL/test_ssl_containers.py
@@ -2,10 +2,11 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
 # ------------------------------------------------------------------------------------------
+import math
 from pathlib import Path
+from typing import Dict
 from unittest import mock
 
-import math
 import numpy as np
 import pandas as pd
 import pytest
@@ -15,7 +16,6 @@
 from pytorch_lightning.callbacks import ModelCheckpoint
 from torch.nn import Module
 from torch.optim.lr_scheduler import _LRScheduler
-from typing import Dict
 
 from InnerEye.Common import fixed_paths
 from InnerEye.Common.common_util import is_windows
@@ -29,6 +29,7 @@
 from InnerEye.ML.SSL.lightning_modules.ssl_online_evaluator import SSLOnlineEvaluatorInnerEye
 from InnerEye.ML.SSL.utils import SSLDataModuleType, SSLTrainingType
 from InnerEye.ML.common import BEST_CHECKPOINT_FILE_NAME_WITH_SUFFIX
+from InnerEye.ML.configs.ssl.CIFAR_SSL_configs import CIFAR10SimCLR
 from InnerEye.ML.configs.ssl.CXR_SSL_configs import CXRImageClassifier
 from InnerEye.ML.runner import Runner
 from Tests.ML.configs.lightning_test_containers import DummyContainerWithModel
@@ -402,3 +403,26 @@ def test_online_evaluator_distributed() -> None:
     # We still need to mock DDP here because the constructor relies on having a process group available
     mock_ddp.assert_called_once_with(mock_sync_result, device_ids=[device])
     assert callback.evaluator == mock_ddp_result
+
+
+def test_simclr_batch_size() -> None:
+    """
+    Test that the number of nodes is correctly passed through to the SimCLR model. After an update of the semantics
+    of the "gpus" argument in LightningBolts, we had a regression, leading to incorrect use of the cosine
+    LR scheduler.
+    """
+    with mock.patch("InnerEye.ML.deep_learning_config.TrainerParams.num_gpus_per_node", return_value=1):
+        with mock.patch("InnerEye.ML.SSL.lightning_containers.ssl_container.get_encoder_output_dim", return_value=1):
+            container = CIFAR10SimCLR()
+            num_samples = 100
+            batch_size = 10
+            container.data_module = mock.MagicMock(num_samples=num_samples, batch_size=batch_size)
+            assert container.num_nodes == 1
+            model1 = container.create_model()
+            old_iters_per_epoch = model1.train_iters_per_epoch
+            assert old_iters_per_epoch == num_samples / batch_size
+            # Increasing the number of nodes should increase the effective batch size, and hence reduce the
+            # number of iterations per epoch
+            container.num_nodes = 2
+            model2 = container.create_model()
+            assert model2.train_iters_per_epoch == old_iters_per_epoch // container.num_nodes  # type: ignore
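The test's final assertion follows from the same arithmetic: with one GPU per node, doubling the node count halves the iterations per epoch. A standalone check of the numbers used in the test (a sketch; `iters_per_epoch` is an illustrative helper, not part of the InnerEye code base):

```python
# Same values as in test_simclr_batch_size above
num_samples, batch_size, gpus_per_node = 100, 10, 1

def iters_per_epoch(num_nodes: int) -> int:
    # The effective batch size grows with the number of nodes, so each epoch needs fewer steps
    return num_samples // (num_nodes * gpus_per_node * batch_size)

assert iters_per_epoch(num_nodes=1) == 10  # matches old_iters_per_epoch in the test
assert iters_per_epoch(num_nodes=2) == 5   # matches old_iters_per_epoch // 2
```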