ref: moved ___step_end hooks (#3130)
* moved eval hooks
williamFalcon authored Aug 24, 2020
1 parent c556ee6 commit 0b3cb3c
Showing 9 changed files with 74 additions and 28 deletions.
9 changes: 9 additions & 0 deletions pytorch_lightning/accelerators/base_backend.py
@@ -13,3 +13,12 @@ def batch_to_device(self, batch: Any, device: torch.device):
if model is not None:
return model.transfer_batch_to_device(batch, device)
return move_data_to_device(batch, device)

def training_step_end(self, output):
return output

def test_step_end(self, output):
return output

def validation_step_end(self, output):
return output
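
For orientation, here is a minimal sketch of the shared base class these defaults live on. The constructor is inferred from the `super().__init__(trainer)` calls added in the subclasses below; the actual `Accelerator` in base_backend.py also carries `batch_to_device` and other logic not shown here.

class Accelerator(object):

    def __init__(self, trainer):
        # assumed: the base class simply stores the trainer reference
        self.trainer = trainer

    # the *_step_end hooks default to pass-through; the dp/ddp2 backends
    # override them to reduce per-device Result outputs automatically
    def training_step_end(self, output):
        return output

    def test_step_end(self, output):
        return output

    def validation_step_end(self, output):
        return output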
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/cpu_backend.py
@@ -13,12 +13,13 @@
# limitations under the License.

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.accelerators.base_backend import Accelerator


class CPUBackend(object):
class CPUBackend(Accelerator):

def __init__(self, trainer):
self.trainer = trainer
super().__init__(trainer)

def setup(self, model):
# run through amp wrapper
21 changes: 19 additions & 2 deletions pytorch_lightning/accelerators/ddp2_backend.py
@@ -20,6 +20,8 @@
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -35,10 +37,10 @@
amp = None


class DDP2Backend(object):
class DDP2Backend(Accelerator):

def __init__(self, trainer):
self.trainer = trainer
super().__init__(trainer)
self.task_idx = None

def setup(self):
@@ -168,3 +170,18 @@ def validation_step(self, args):
def test_step(self, args):
output = self.training_step(args)
return output

def training_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output

def validation_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output

def test_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output
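
The reduction these overrides delegate to `Result.dp_reduce()` is conceptually a per-key average of the outputs gathered from each device under dp/ddp2. A toy illustration of the idea, not the actual Result implementation (whose details live in pytorch_lightning.core.step_result):

import torch

def toy_dp_reduce(per_device_outputs):
    # hypothetical helper for illustration only; Result.dp_reduce()
    # does the equivalent bookkeeping internally for Result objects
    return {k: v.float().mean(dim=0) for k, v in per_device_outputs.items()}

# e.g. a loss gathered from 2 GPUs
gathered = {'loss': torch.tensor([1.0, 3.0])}
print(toy_dp_reduce(gathered))  # {'loss': tensor(2.)}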
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_backend.py
@@ -25,6 +25,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
from hydra.utils import to_absolute_path, get_original_cwd
@@ -40,10 +41,10 @@
amp = None


class DDPBackend(object):
class DDPBackend(Accelerator):

def __init__(self, trainer):
self.trainer = trainer
super().__init__(trainer)
self.task_idx = None
self._has_spawned_children = False

5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -19,17 +19,18 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, find_free_network_port
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
from apex import amp
except ImportError:
amp = None


class DDPSpawnBackend(object):
class DDPSpawnBackend(Accelerator):

def __init__(self, trainer):
self.trainer = trainer
super().__init__(trainer)
self.mp_queue = None

def setup(self):
21 changes: 19 additions & 2 deletions pytorch_lightning/accelerators/dp_backend.py
@@ -18,17 +18,19 @@
from pytorch_lightning.overrides.data_parallel import LightningDataParallel
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
from apex import amp
except ImportError:
amp = None


class DataParallelBackend(object):
class DataParallelBackend(Accelerator):

def __init__(self, trainer):
self.trainer = trainer
super().__init__(trainer)
self.model_autocast_original_forward = None

def setup(self, model):
@@ -113,6 +115,21 @@ def test_step(self, args):
output = self.training_step(args)
return output

def training_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output

def validation_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output

def test_step_end(self, output):
if isinstance(output, Result):
output.dp_reduce()
return output

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
"""
Reinitialize optimizer.step properties added by schedulers
20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -330,23 +330,23 @@ def _evaluate(
# ------------------
# EVAL STEP END
# ------------------
# on dp / ddp2 might still want to do something with the batch parts
if self.is_overridden('test_step_end') or self.is_overridden('validation_step_end'):
if test_mode:
output = self.call_hook('test_step_end', output)
else:
output = self.call_hook('validation_step_end', output)

elif is_result_obj and (self.use_dp or self.use_ddp2):
# result auto reduce
output.dp_reduce()
if test_mode:
output = self.call_hook('test_step_end', output)
else:
output = self.call_hook('validation_step_end', output)

# ------------------
# Hook: on_eval_batch_end
# ------------------
# callbacks (on __batch_end)
if test_mode:
self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
else:
self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)

# ----------------------
# Post processing
# ----------------------
# track outputs for collation
if output is not None:

6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1472,6 +1472,12 @@ def call_hook(self, hook_name, *args, **kwargs):
hook_fx = getattr(model_ref, hook_name)
output = hook_fx(*args, **kwargs)

# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.accelerator_backend, hook_name):
accelerator_hook = getattr(self.accelerator_backend, hook_name)
output = accelerator_hook(*args, **kwargs)

return output


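Combined with the simplified evaluation and training loops below, the net effect of this dispatch is that the *_step_end hooks always run through call_hook: a hook defined on the LightningModule wins, otherwise the accelerator default applies (pass-through, or dp_reduce on dp/ddp2). A self-contained sketch of that precedence with stand-in objects; the real call_hook guards the first branch with checks not shown in this hunk:

class _UserModule:
    # stand-in for a LightningModule that overrides the hook
    def training_step_end(self, output):
        output['loss'] = output['loss'] * 2
        return output


class _Backend:
    # stand-in for an accelerator that only provides the default
    def training_step_end(self, output):
        return output


def call_hook(module, backend, hook_name, output):
    # module override first, then the accelerator default
    if hasattr(module, hook_name):
        return getattr(module, hook_name)(output)
    if hasattr(backend, hook_name):
        return getattr(backend, hook_name)(output)
    return output


print(call_hook(_UserModule(), _Backend(), 'training_step_end', {'loss': 1.0}))  # {'loss': 2.0}
print(call_hook(object(), _Backend(), 'training_step_end', {'loss': 1.0}))       # {'loss': 1.0}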
10 changes: 2 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -1201,14 +1201,8 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
# distributed forward
output = self.accelerator_backend.training_step(args)

is_result_obj = isinstance(output, Result)

# allow any mode to define training_step_end
# do something with all the dp outputs (like softmax)
if self.is_overridden('training_step_end'):
output = self.call_hook('training_step_end', output)
elif is_result_obj and (self.use_dp or self.use_ddp2):
output.dp_reduce()
# Training step end
output = self.call_hook('training_step_end', output)

return output

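Because training_step_end is now always routed through call_hook, a user-defined override keeps working as before and takes precedence over the accelerator default. An illustrative LightningModule that combines the per-GPU batch parts itself under dp/ddp2 (example code, not part of this commit):

import torch
from torch import nn
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    # minimal illustrative module
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': nn.functional.mse_loss(self(x), y)}

    def training_step_end(self, batch_parts_outputs):
        # under dp/ddp2 'loss' arrives with one entry per GPU; reduce it here
        return {'loss': batch_parts_outputs['loss'].mean()}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)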
