Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5/n: Extract reference model call to plugins/accelerators #4773

Merged
merged 30 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
be4c24c
Encapsulate extracting reference model within the plugin to allow cus…
Nov 19, 2020
5101696
Add missing new lines
Nov 19, 2020
078a829
Fix call to accelerator
Nov 19, 2020
aeab93c
Removed double blank
Nov 19, 2020
95a1f19
Use accelerator backend
Nov 19, 2020
84ccdbf
Handle case where wrapper has not been initialized within the plugin
Nov 19, 2020
0864b1c
Added basic get model tests, add better typing
Nov 19, 2020
142a2d3
Change model name
Nov 19, 2020
6e548df
Split GPU/DDP test
Nov 19, 2020
aebb1a3
Add stronger typing, skip ddp test on windows
Nov 19, 2020
fa04807
Fix import
Nov 19, 2020
47e562e
Fix import in dp
Nov 19, 2020
15734e9
Fixed PEP8 definition
Nov 19, 2020
f29f7c5
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 19, 2020
3a7a848
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
10a3a1e
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
e3869c3
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
3a3eaa5
Merge branch 'master' into feature/817-fairscale-5n
tchaton Nov 20, 2020
b44dd75
Add ddp launcher for ddp testing
Nov 20, 2020
6786407
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
9a07f67
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 20, 2020
358f503
Modify accelerator reference model to property, change name to reflec…
Nov 22, 2020
4b16b47
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 22, 2020
977625c
Revert property as this is incorrect.=
Nov 22, 2020
250cd96
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 22, 2020
b506a7e
Revert across accelerators
Nov 22, 2020
86eb0f9
Merge branch 'master' into feature/817-fairscale-5n
SeanNaren Nov 23, 2020
0d8aeef
Modified name to get_model_from_plugin
Nov 23, 2020
b937117
Code review changes, fix issue with dp
Nov 23, 2020
d9289fb
Add verb to function getter
Nov 23, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from enum import Enum
from typing import Any, Optional, Union, List
from typing import Any, Optional, Union

import torch
from torch.optim import Optimizer
Expand All @@ -22,8 +22,8 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.core.lightning import LightningModule
import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log

if torch.distributed.is_available():
from torch.distributed import ReduceOp
Expand Down Expand Up @@ -208,6 +208,23 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
return self.ddp_plugin.optimizer_state(optimizer)
return optimizer.state_dict()

def get_reference_model(self, model) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
when accessing variable and functions if the accelerator has wrapped the model.

Example::
ref_model = accelerator.get_reference_model(model)
ref_model.training_step(...)

Args:
model: Accelerator model.

Returns: Reference :class:`LightningModule`.

"""
return model

def __getstate__(self):
return {
'trainer': self.trainer,
Expand Down
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,6 @@ def sync_tensor(self,

"""
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,3 +213,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
3 changes: 3 additions & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,6 @@ def sync_tensor(self,
group: Optional[Any] = None,
reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
return sync_ddp_if_available(tensor, group, reduce_op)

def get_reference_model(self, model) -> LightningModule:
return self.ddp_plugin.get_model_from_plugin(model)
7 changes: 7 additions & 0 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union

import torch
from torch import optim

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.core.step_result import Result
Expand Down Expand Up @@ -172,3 +174,8 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)

def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
return model
24 changes: 23 additions & 1 deletion pytorch_lightning/plugins/ddp_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any, Dict, List, Union

import torch.distributed as torch_distrib
from pytorch_lightning import _logger as log
Expand Down Expand Up @@ -108,3 +108,25 @@ def on_before_forward(self, model, *args):

def optimizer_state(self, optimizer: Optimizer) -> dict:
return optimizer.state_dict()

def get_model_from_plugin(
self,
model: Union[LightningDistributedDataParallel, LightningModule]
) -> LightningModule:
"""
Override to modify returning base :class:`LightningModule`
when accessing variable and functions outside of the parallel wrapper.

Example::
ref_model = ddp_plugin.get_model_from_plugin(model)
ref_model.training_step(...)

Args:
model: Model with parallel wrapper.

Returns: Reference :class:`LightningModule` within parallel wrapper.

"""
if isinstance(model, LightningDistributedDataParallel):
return model.module
return model
18 changes: 6 additions & 12 deletions pytorch_lightning/trainer/connectors/model_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,14 @@
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.

"""
from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)


class ModelConnector:
def __init__(self, trainer):
self.trainer = trainer

def copy_trainer_model_properties(self, model):
if isinstance(model, LightningDataParallel):
ref_model = model.module
elif isinstance(model, LightningDistributedDataParallel):
ref_model = model.module
else:
ref_model = model
ref_model = self._get_reference_model(model)

automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
self.trainer.train_loop.automatic_optimization = automatic_optimization
Expand All @@ -55,6 +46,9 @@ def copy_trainer_model_properties(self, model):
m.local_rank = self.trainer.local_rank

def get_model(self):
is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
model = self.trainer.model.module if is_dp_module else self.trainer.model
return self._get_reference_model(self.trainer.model)

def _get_reference_model(self, model):
if self.trainer.accelerator_backend:
return self.trainer.accelerator_backend.get_reference_model(model)
return model
110 changes: 110 additions & 0 deletions tests/trainer/properties/test_get_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys

import pytest
import torch

from pytorch_lightning import Trainer
from tests.backends.launcher import DDPLauncher
Borda marked this conversation as resolved.
Show resolved Hide resolved
from tests.base.boring_model import BoringModel


class TrainerGetModel(BoringModel):
def on_fit_start(self):
assert self == self.trainer.get_model()

def on_fit_end(self):
assert self == self.trainer.get_model()


def test_get_model(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
)
trainer.fit(model)


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_get_model_ddp_cpu(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
accelerator='ddp_cpu',
num_processes=2
)
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_get_model_gpu(tmpdir):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
gpus=1
)
trainer.fit(model)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
@DDPLauncher.run("--accelerator [accelerator]",
max_epochs=["1"],
accelerator=["ddp", "ddp_spawn"])
def test_get_model_ddp_gpu(tmpdir, args=None):
"""
Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators
"""

model = TrainerGetModel()

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
gpus=1,
accelerator=args.accelerator
)
trainer.fit(model)
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
return 1