Automatically set sync_batchnorm for training_type_plugin (#6536)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Co-authored-by: Kaushik Bokka <kaushikbokka@gmail.com>
(cherry picked from commit 3b72bcc)
amogkam authored and Borda committed Mar 23, 2021
1 parent c3be721 commit da85776
Showing 2 changed files with 46 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -425,6 +425,11 @@ def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> Tra
        if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
            training_type.num_nodes = self.num_nodes

        # Automatically set sync_batchnorm if None.
        # Useful for custom plugins.
        if hasattr(training_type, 'sync_batchnorm') and getattr(training_type, 'sync_batchnorm') is None:
            training_type.sync_batchnorm = self.sync_batchnorm

        return training_type

    def select_accelerator(self) -> Accelerator:
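
For context only, a minimal plain-PyTorch sketch of what the resolved sync_batchnorm flag ultimately controls on the plugin side. The helper name below is illustrative and not Lightning's internal API; the conversion call torch.nn.SyncBatchNorm.convert_sync_batchnorm is standard PyTorch.

import torch


def apply_sync_batchnorm(model: torch.nn.Module, sync_batchnorm: bool) -> torch.nn.Module:
    # Illustrative helper: when the flag is set, replace every BatchNorm*D
    # layer with SyncBatchNorm so running statistics are synchronized across
    # all distributed processes before the model is wrapped in DDP.
    if sync_batchnorm:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    return model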
41 changes: 41 additions & 0 deletions tests/plugins/test_custom_plugin.py
@@ -0,0 +1,41 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


class CustomParallelPlugin(DDPPlugin):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Set to None so it will be overwritten by the accelerator connector.
        self.sync_batchnorm = None


@RunIf(skip_windows=True)
def test_sync_batchnorm_set(tmpdir):
    """Tests if sync_batchnorm is automatically set for custom plugin."""
    model = BoringModel()
    plugin = CustomParallelPlugin()
    assert plugin.sync_batchnorm is None
    trainer = Trainer(
        max_epochs=1,
        plugins=[plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
    )
    trainer.fit(model)
    assert plugin.sync_batchnorm is True
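
Usage note (a sketch, not part of the diff): with this change, a user-defined plugin can leave sync_batchnorm as None and rely on the Trainer argument, mirroring how num_nodes is resolved above. The plugin name below is hypothetical.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin


class MyClusterDDPPlugin(DDPPlugin):  # hypothetical custom plugin

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.sync_batchnorm = None  # filled in by the accelerator connector


# sync_batchnorm=True on the Trainer is copied onto the plugin instance,
# so it no longer needs to be passed to the plugin constructor explicitly.
trainer = Trainer(sync_batchnorm=True, plugins=[MyClusterDDPPlugin()])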
