ref: added model connector (#3407)
* ref: added model connector

* ref: added model connector

* ref: added model connector
williamFalcon authored Sep 9, 2020
1 parent 722c44c commit 8f6b115
Showing 10 changed files with 61 additions and 122 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -136,7 +136,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         self.trainer.optimizer_frequencies = optimizer_frequencies

         # set model properties before going into wrapper
-        self.trainer.copy_trainer_model_properties(model)
+        self.trainer.model_connector.copy_trainer_model_properties(model)

         # AMP - run through amp wrapper before going to distributed DP
         if self.trainer.amp_backend == AMPType.APEX:
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -217,7 +217,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         self.trainer.optimizer_frequencies = optimizer_frequencies

         # set model properties before going into wrapper
-        self.trainer.copy_trainer_model_properties(model)
+        self.trainer.model_connector.copy_trainer_model_properties(model)

         # AMP - run through amp wrapper before going to distributed DP
         if self.trainer.amp_backend == AMPType.APEX:
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -141,7 +141,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         self.trainer.optimizer_frequencies = optimizer_frequencies

         # set model properties before going into wrapper
-        self.trainer.copy_trainer_model_properties(model)
+        self.trainer.model_connector.copy_trainer_model_properties(model)

         # AMP -
         # run through amp wrapper before going to distributed DP
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -200,10 +200,6 @@ def call_setup_hook(self, *args):
     def num_gpus(self) -> int:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def copy_trainer_model_properties(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def init_optimizers(self, *args) -> Tuple[List, List, List]:
         """Warning: this is just empty shell for code implemented in other class."""
97 changes: 0 additions & 97 deletions pytorch_lightning/trainer/distrib_parts.py

This file was deleted.

2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/evaluate_loop.py
@@ -75,7 +75,7 @@ def is_using_eval_results(self):

     def setup(self, model, max_batches, dataloaders):
         # copy properties for forward overrides
-        self.trainer.copy_trainer_model_properties(model)
+        self.trainer.model_connector.copy_trainer_model_properties(model)

         # bookkeeping
         self.outputs = []
12 changes: 0 additions & 12 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -195,22 +195,10 @@ class TrainerEvaluationLoopMixin(ABC):
     accelerator_backend: ...
     evaluation_loop: EvaluationLoop

-    @abstractmethod
-    def copy_trainer_model_properties(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def reset_test_dataloader(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
-    @abstractmethod
-    def reset_val_dataloader(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def call_hook(self, hook_name, *args, **kwargs):
         """Warning: this is just empty shell for code implemented in other class."""
52 changes: 52 additions & 0 deletions pytorch_lightning/trainer/model_connector.py
@@ -0,0 +1,52 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Root module for all distributed operations in Lightning.
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
from pytorch_lightning.overrides.data_parallel import (
    LightningDistributedDataParallel,
    LightningDataParallel,
)


class ModelConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def copy_trainer_model_properties(self, model):
        if isinstance(model, LightningDataParallel):
            ref_model = model.module
        elif isinstance(model, LightningDistributedDataParallel):
            ref_model = model.module
        else:
            ref_model = model

        for m in [model, ref_model]:
            m.trainer = self.trainer
            m.logger = self.trainer.logger
            m.use_dp = self.trainer.use_dp
            m.use_ddp2 = self.trainer.use_ddp2
            m.use_ddp = self.trainer.use_ddp
            m.use_amp = self.trainer.amp_backend is not None
            m.testing = self.trainer.testing
            m.use_single_gpu = self.trainer.use_single_gpu
            m.use_tpu = self.trainer.use_tpu
            m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
            m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
            m.precision = self.trainer.precision
            m.global_rank = self.trainer.global_rank
            m.local_rank = self.trainer.local_rank
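
For orientation, a minimal self-contained sketch of the connector pattern this new file introduces. This is illustrative only and not part of the commit: SketchModelConnector, DummyTrainer, and DummyModule are hypothetical stand-ins, and the real ModelConnector above additionally unwraps LightningDataParallel / LightningDistributedDataParallel and copies the full set of trainer flags.

# Illustrative sketch only: mirrors the shape of the new ModelConnector using
# hypothetical stand-in classes instead of the Lightning Trainer and wrappers.

class SketchModelConnector:
    """Copies trainer-level settings onto the (possibly wrapped) model."""

    def __init__(self, trainer):
        self.trainer = trainer

    def copy_trainer_model_properties(self, model):
        # Unwrap a DataParallel-style wrapper so the inner module is updated too
        # (the real class checks the Lightning wrapper classes explicitly).
        ref_model = getattr(model, "module", model)
        for m in [model, ref_model]:
            m.trainer = self.trainer
            m.precision = self.trainer.precision
            m.global_rank = self.trainer.global_rank
            m.local_rank = self.trainer.local_rank


class DummyTrainer:
    def __init__(self):
        self.precision = 32
        self.global_rank = 0
        self.local_rank = 0
        # The trainer owns the connector, as Trainer.__init__ does in this commit.
        self.model_connector = SketchModelConnector(self)


class DummyModule:
    pass


if __name__ == "__main__":
    trainer = DummyTrainer()
    model = DummyModule()
    trainer.model_connector.copy_trainer_model_properties(model)
    assert model.precision == 32 and model.trainer is trainer
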
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -37,7 +37,6 @@
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.utilities import device_parser
-from pytorch_lightning.trainer.distrib_parts import (TrainerDPMixin)
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
@@ -57,6 +56,7 @@
 from pytorch_lightning.trainer.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
 from pytorch_lightning.trainer.training_loop import TrainLoop
+from pytorch_lightning.trainer.model_connector import ModelConnector
 from pytorch_lightning import _logger as log
 from pytorch_lightning.tuner.tuning import Tuner
 from pytorch_lightning.utilities.model_utils import is_overridden
@@ -94,7 +94,6 @@ class Trainer(
     TrainerModelHooksMixin,
     TrainerOptimizersMixin,
     TrainerAMPMixin,
-    TrainerDPMixin,
     TrainerDDPMixin,
     TrainerLoggingMixin,
     TrainerTrainingTricksMixin,
@@ -380,6 +379,7 @@ def __init__(
         self.lr_scheduler_connector = LRSchedulerConnector(self)
         self.accelerator_connector = AcceleratorConnector(self)
         self.logger_connector = LoggerConnector(self)
+        self.model_connector = ModelConnector(self)
         self.tuner = Tuner(self)
         self.accelerator_backend = None

@@ -1060,7 +1060,7 @@ def fit(

     def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
         # bind logger and other properties
-        self.copy_trainer_model_properties(model)
+        self.model_connector.copy_trainer_model_properties(model)

         # clean hparams
         if hasattr(model, 'hparams'):
@@ -1089,7 +1089,7 @@ def setup_training(self, model: LightningModule):
         ref_model.trainer = self

         # set local properties on the model
-        self.copy_trainer_model_properties(ref_model)
+        self.model_connector.copy_trainer_model_properties(ref_model)

         # init amp. Must be done here instead of __init__ to allow ddp to work
         if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
2 changes: 1 addition & 1 deletion tests/models/test_horovod.py
@@ -205,7 +205,7 @@ def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
     num_workers = 8
     init_lr = hparams.get('learning_rate') * num_workers

-    with patch('pytorch_lightning.trainer.distrib_parts.hvd.size') as mock_hvd_size:
+    with patch('pytorch_lightning.accelerators.horovod_backend.hvd.size') as mock_hvd_size:
         mock_hvd_size.return_value = 8

         # fit model
