
ref: moved ___step_end hooks #3130

Merged · 7 commits · Aug 24, 2020
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -13,3 +13,12 @@ def batch_to_device(self, batch: Any, device: torch.device):
         if model is not None:
             return model.transfer_batch_to_device(batch, device)
         return move_data_to_device(batch, device)
+
+    def training_step_end(self, output):
+        return output
+
+    def test_step_end(self, output):
+        return output
+
+    def validation_step_end(self, output):
+        return output
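
The three new hooks above are deliberate pass-throughs: every accelerator now exposes training_step_end / validation_step_end / test_step_end, so single-device backends inherit a no-op while data-parallel backends override it to reduce. A minimal, self-contained sketch of that pattern (stand-in classes, not the real Lightning Accelerator API):

import torch


class BaseAcceleratorSketch:
    """Stand-in for the base accelerator: step-end hooks pass outputs through."""

    def training_step_end(self, output):
        return output


class DataParallelAcceleratorSketch(BaseAcceleratorSketch):
    """Stand-in for a dp/ddp2-style accelerator: reduce per-device outputs."""

    def training_step_end(self, output):
        # average a per-device tensor (one entry per GPU) down to a scalar,
        # loosely mimicking what Result.dp_reduce() is used for in this PR
        if isinstance(output, torch.Tensor) and output.dim() > 0:
            output = output.mean()
        return output


per_device_loss = torch.tensor([0.4, 0.6])  # as if gathered from 2 GPUs
print(BaseAcceleratorSketch().training_step_end(per_device_loss))          # unchanged
print(DataParallelAcceleratorSketch().training_step_end(per_device_loss))  # tensor(0.5000)
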
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/cpu_backend.py
@@ -13,12 +13,13 @@
 # limitations under the License.

 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.accelerators.base_backend import Accelerator


-class CPUBackend(object):
+class CPUBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)

     def setup(self, model):
         # run through amp wrapper
21 changes: 19 additions & 2 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -20,6 +20,8 @@
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -35,10 +37,10 @@
     amp = None


-class DDP2Backend(object):
+class DDP2Backend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.task_idx = None

     def setup(self):
@@ -168,3 +170,18 @@ def validation_step(self, args):
     def test_step(self, args):
         output = self.training_step(args)
         return output
+
+    def training_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
+
+    def validation_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
+
+    def test_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
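
The overrides above only auto-reduce when the step returned a Result object; anything else is passed through unchanged, so existing user code keeps its behavior. A hedged sketch of the kind of reduction dp_reduce() is used for here, assuming it averages values stacked along the device dimension (ResultSketch is a stand-in, not pytorch_lightning.core.step_result.Result):

import torch


class ResultSketch(dict):
    """Illustration only; stands in for pytorch_lightning.core.step_result.Result."""

    def dp_reduce(self):
        # assume each value is stacked along dim 0 with one slice per device,
        # and reduce by averaging across that dimension
        for key, value in self.items():
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                self[key] = value.float().mean(dim=0)


output = ResultSketch(loss=torch.tensor([0.25, 0.35]))  # per-GPU losses under dp
if isinstance(output, ResultSketch):  # mirrors the isinstance(output, Result) guard above
    output.dp_reduce()
print(output['loss'])  # tensor(0.3000)
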
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -25,6 +25,7 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -40,10 +41,10 @@
     amp = None


-class DDPBackend(object):
+class DDPBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -19,17 +19,18 @@
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     from apex import amp
 except ImportError:
     amp = None


-class DDPSpawnBackend(object):
+class DDPSpawnBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.mp_queue = None

     def setup(self):
21 changes: 19 additions & 2 deletions pytorch_lightning/accelerators/dp_backend.py
@@ -18,17 +18,19 @@
 from pytorch_lightning.overrides.data_parallel import LightningDataParallel
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.accelerators.base_backend import Accelerator

 try:
     from apex import amp
 except ImportError:
     amp = None


-class DataParallelBackend(object):
+class DataParallelBackend(Accelerator):

     def __init__(self, trainer):
-        self.trainer = trainer
+        super().__init__(trainer)
         self.model_autocast_original_forward = None

     def setup(self, model):
@@ -113,6 +115,21 @@ def test_step(self, args):
         output = self.training_step(args)
         return output

+    def training_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
+
+    def validation_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
+
+    def test_step_end(self, output):
+        if isinstance(output, Result):
+            output.dp_reduce()
+        return output
+
     def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
         """
         Reinitialize optimizer.step properties added by schedulers
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -330,23 +330,23 @@ def _evaluate(
                 # ------------------
                 # EVAL STEP END
                 # ------------------
-                # on dp / ddp2 might still want to do something with the batch parts
-                if self.is_overridden('test_step_end') or self.is_overridden('validation_step_end'):
-                    if test_mode:
-                        output = self.call_hook('test_step_end', output)
-                    else:
-                        output = self.call_hook('validation_step_end', output)
-
-                elif is_result_obj and (self.use_dp or self.use_ddp2):
-                    # result auto reduce
-                    output.dp_reduce()
+                if test_mode:
+                    output = self.call_hook('test_step_end', output)
+                else:
+                    output = self.call_hook('validation_step_end', output)

+                # ------------------
+                # Hook: on_eval_batch_end
+                # ------------------
                 # callbacks (on __batch_end)
                 if test_mode:
                     self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
                 else:
                     self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)

+                # ----------------------
+                # Post processing
+                # ----------------------
                 # track outputs for collation
                 if output is not None:
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1472,6 +1472,12 @@ def call_hook(self, hook_name, *args, **kwargs):
             hook_fx = getattr(model_ref, hook_name)
             output = hook_fx(*args, **kwargs)

+        # if the PL module doesn't have the hook then call the accelerator
+        # used to auto-reduce things for the user with Result objects
+        elif hasattr(self.accelerator_backend, hook_name):
+            accelerator_hook = getattr(self.accelerator_backend, hook_name)
+            output = accelerator_hook(*args, **kwargs)
+
         return output

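
The comment in this hunk states the dispatch rule the whole refactor leans on: a hook overridden on the LightningModule wins, and the accelerator implementation is only the fallback. A simplified sketch of that precedence, with made-up names for everything the diff does not show:

class CallHookSketch:
    """Illustrates 'module override wins, accelerator is the fallback'."""

    def __init__(self, model, accelerator, user_overrides):
        self.model = model
        self.accelerator = accelerator
        self.user_overrides = set(user_overrides)  # hook names the user defined

    def call_hook(self, hook_name, *args, **kwargs):
        output = None
        if hook_name in self.user_overrides:
            # the LightningModule's own hook runs (e.g. a custom training_step_end)
            output = getattr(self.model, hook_name)(*args, **kwargs)
        elif hasattr(self.accelerator, hook_name):
            # otherwise the accelerator default runs (e.g. dp auto-reduction)
            output = getattr(self.accelerator, hook_name)(*args, **kwargs)
        return output


class PassThroughAccelerator:
    def training_step_end(self, output):
        return output  # the base-accelerator default added by this PR


caller = CallHookSketch(model=object(), accelerator=PassThroughAccelerator(), user_overrides=[])
print(caller.call_hook('training_step_end', 1.23))  # 1.23, via the accelerator fallback
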
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -1201,14 +1201,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
         # distributed forward
         output = self.accelerator_backend.training_step(args)

-        is_result_obj = isinstance(output, Result)
-
-        # allow any mode to define training_step_end
-        # do something will all the dp outputs (like softmax)
-        if self.is_overridden('training_step_end'):
-            output = self.call_hook('training_step_end', output)
-        elif is_result_obj and (self.use_dp or self.use_ddp2):
-            output.dp_reduce()
+        # Training step end
+        output = self.call_hook('training_step_end', output)

         return output
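
With the special-casing removed, training_step_end is now always routed through call_hook, so a user override still runs exactly as before; the deleted comment's "do something with all the dp outputs (like softmax)" case becomes an ordinary LightningModule hook. A hedged example of such an override (hypothetical module, not part of this PR):

import torch
import torch.nn.functional as F


class LitModelSketch:
    """Stand-in for a LightningModule that defines the optional hook."""

    def training_step_end(self, step_outputs):
        # outputs from all dp parts arrive gathered here; compute softmax and
        # the loss over the full batch rather than per device
        log_probs = F.log_softmax(step_outputs['logits'], dim=-1)
        loss = F.nll_loss(log_probs, step_outputs['targets'])
        return {'loss': loss}


model = LitModelSketch()
fake_outputs = {'logits': torch.randn(8, 10), 'targets': torch.randint(0, 10, (8,))}
print(model.training_step_end(fake_outputs)['loss'])
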