Skip to content

Commit

Permalink
5/n: Extract reference model call to plugins/accelerators (#4773)
Browse files Browse the repository at this point in the history
* Encapsulate extracting reference model within the plugin to allow custom wrapper logic to live within the plugin/accelerators

* Add missing new lines

* Fix call to accelerator

* Removed double blank

* Use accelerator backend

* Handle case where wrapper has not been initialized within the plugin

* Added basic get model tests, add better typing

* Change model name

* Split GPU/DDP test

* Add stronger typing, skip ddp test on windows

* Fix import

* Fix import in dp

* Fixed PEP8 definition

* Add ddp launcher for ddp testing

* Modify accelerator reference model to property, change name to reflect func

* Revert property as this is incorrect.=

* Revert across accelerators

* Modified name to get_model_from_plugin

* Code review changes, fix issue with dp

* Add verb to function getter

Co-authored-by: chaton <thomas@grid.ai>
  • Loading branch information
SeanNaren and tchaton committed Nov 23, 2020
1 parent 6831ba9 commit 404af43
Show file tree
Hide file tree
Showing 10 changed files with 181 additions and 16 deletions.
23 changes: 20 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from enum import Enum
from typing import Any, Optional, Union, List
from typing import Any, Optional, Union

import torch
from torch.optim import Optimizer
Expand All @@ -22,8 +22,8 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.core.lightning import LightningModule
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

if torch.distributed.is_available():
from torch.distributed import ReduceOp
Expand Down Expand Up @@ -208,6 +208,23 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
return self.ddp_plugin.optimizer_state(optimizer)
return optimizer.state_dict()

def get_reference_model(self, model) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
when accessing variable and functions if the accelerator has wrapped the model.
Example::
ref_model = accelerator.get_reference_model(model)
ref_model.training_step(...)
Args:
model: Accelerator model.
Returns: Reference :class:`LightningModule`.
"""
return model

def __getstate__(self):
return {
'trainer': self.trainer,
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,6 @@ def sync_tensor(self,
"""
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
7 changes: 7 additions & 0 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch import optim

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.core.step_result import Result
Expand Down Expand Up @@ -172,3 +174,8 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)

def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
return model
24 changes: 23 additions & 1 deletion pytorch_lightning/plugins/ddp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -108,3 +108,25 @@ def on_before_forward(self, model, *args):

def optimizer_state(self, optimizer: Optimizer) -> dict:
return optimizer.state_dict()

def get_model_from_plugin(
self,
model: Union[LightningDistributedDataParallel, LightningModule]
) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
when accessing variable and functions outside of the parallel wrapper.
Example::
ref_model = ddp_plugin.get_model_from_plugin(model)
ref_model.training_step(...)
Args:
model: Model with parallel wrapper.
Returns: Reference :class:`LightningModule` within parallel wrapper.
"""
if isinstance(model, LightningDistributedDataParallel):
return model.module
return model
18 changes: 6 additions & 12 deletions pytorch_lightning/trainer/connectors/model_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,14 @@
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)


class ModelConnector:
def __init__(self, trainer):
self.trainer = trainer

def copy_trainer_model_properties(self, model):
if isinstance(model, LightningDataParallel):
ref_model = model.module
elif isinstance(model, LightningDistributedDataParallel):
ref_model = model.module
else:
ref_model = model
ref_model = self._get_reference_model(model)

automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
self.trainer.train_loop.automatic_optimization = automatic_optimization
Expand All @@ -55,6 +46,9 @@ def copy_trainer_model_properties(self, model):
m.local_rank = self.trainer.local_rank

def get_model(self):
is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
model = self.trainer.model.module if is_dp_module else self.trainer.model
return self._get_reference_model(self.trainer.model)

def _get_reference_model(self, model):
if self.trainer.accelerator_backend:
return self.trainer.accelerator_backend.get_reference_model(model)
return model
110 changes: 110 additions & 0 deletions tests/trainer/properties/test_get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pytest
import torch

from pytorch_lightning import Trainer
from tests.backends.launcher import DDPLauncher
from tests.base.boring_model import BoringModel


class TrainerGetModel(BoringModel):
def on_fit_start(self):
assert self == self.trainer.get_model()

def on_fit_end(self):
assert self == self.trainer.get_model()


def test_get_model(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_get_model_ddp_cpu(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
accelerator='ddp_cpu',
num_processes=2
)
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_get_model_gpu(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
gpus=1
)
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@DDPLauncher.run("--accelerator [accelerator]",
max_epochs=["1"],
accelerator=["ddp", "ddp_spawn"])
def test_get_model_ddp_gpu(tmpdir, args=None):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
gpus=1,
accelerator=args.accelerator
)
trainer.fit(model)
return 1

0 comments on commit 404af43

Please sign in to comment.