
Commit

ref: decouple apex second attempt part 4/n (#4056)
* ref: decouple apex second attempt part 4/n

* ref: decouple apex second attempt part 4/n

* Update lightning.py

* ref: decouple apex second attempt part 4/n
williamFalcon committed Oct 10, 2020
1 parent 3a6717c commit ce2edf1
Showing 6 changed files with 33 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/base_accelerator.py
@@ -71,7 +71,7 @@ def validation_step_end(self, output):
     def process_dataloader(self, dataloader):
         return dataloader

-    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, *args, **kwargs):
         if self.trainer.precision == 16:
             closure_loss = self.trainer.precision_connector.backend.backward(closure_loss, optimizer, *args, **kwargs)
         else:
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/horovod_backend.py
@@ -150,8 +150,8 @@ def test_step(self, args):
         output = self.trainer.model.test_step(*args)
         return output

-    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
-        super().backward(closure_loss, optimizer, opt_idx, *args, **kwargs)
+    def backward(self, closure_loss, optimizer, *args, **kwargs):
+        super().backward(closure_loss, optimizer, *args, **kwargs)
         optimizer.synchronize()

     def on_train_epoch_end(self, outputs):
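For context on why the Horovod backend overrides backward at all: hvd.DistributedOptimizer all-reduces gradients asynchronously during backward(), so they must be explicitly synchronized before any gradient clipping or the optimizer step. A minimal, self-contained sketch of that standard Horovod pattern (illustrative only, not code from this commit; requires Horovod to run) could look like this:

```python
import torch
import horovod.torch as hvd

hvd.init()

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# DistributedOptimizer hooks into backward() to allreduce gradients in the background
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()

# Wait for the asynchronous allreduces to finish before touching the gradients,
# which is what horovod_backend.backward() does right after the super() call.
optimizer.synchronize()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# Skip the implicit synchronize inside step(), since we already synchronized above.
with optimizer.skip_synchronize():
    optimizer.step()
```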
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -223,7 +223,7 @@ def __setup_tpu_training(self, model: LightningModule, trainer):
                        f' global rank: {trainer.tpu_global_core_rank}'
                        f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

-    def backward(self, closure_loss, optimizer, opt_idx, *args, **kwargs):
+    def backward(self, closure_loss, optimizer, *args, **kwargs):
         # do backward pass
         closure_loss.backward(*args, **kwargs)

23 changes: 0 additions & 23 deletions pytorch_lightning/core/hooks.py
@@ -295,29 +295,6 @@ def on_after_backward(self):
         """

-    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
-        """
-        Override backward with your own implementation if you need to.
-
-        Args:
-            trainer: Pointer to the trainer
-            loss: Loss is already scaled by accumulated grads
-            optimizer: Current optimizer being used
-            optimizer_idx: Index of the current optimizer being used
-
-        Called to perform backward step.
-        Feel free to override as needed.
-        The loss passed in has already been scaled for accumulated gradients if requested.
-
-        Example::
-
-            def backward(self, trainer, loss, optimizer, optimizer_idx):
-                loss.backward()
-
-        """
-        loss.backward()
-
-
 class DataHooks:
     def prepare_data(self) -> None:
20 changes: 20 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -110,6 +110,9 @@ def __init__(self, *args, **kwargs):
         self._results: Result = None
         self._current_fx_name = ''

+    def optimizers(self):
+        return self.trainer.optimizers
+
     @property
     def example_input_array(self) -> Any:
         return self._example_input_array
@@ -1034,6 +1037,23 @@ def configure_optimizers(self):
             "`configure_optimizers` must be implemented to be used with the Lightning Trainer"
         )

+    def backward(self, loss: Tensor, optimizer: Optimizer, *args, **kwargs) -> None:
+        """
+        Call this directly from your training_step when doing optimizations manually.
+        By using this we can ensure that all the proper scaling when using 16-bit etc has been done for you
+
+        This function forwards all args to the .backward() call as well.
+
+        Example::
+
+            def training_step(...):
+                (opt_a, opt_b) = self.optimizers()
+                loss = ...
+                # automatically applies scaling, etc...
+                self.backward(loss, opt_a)
+        """
+        self.trainer.train_loop.backward(loss, optimizer, *args, **kwargs)
+
     def optimizer_step(
         self,
         epoch: int,
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -746,13 +746,21 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,

         # backward pass
         with self.trainer.profiler.profile('model_backward'):
-            result.closure_loss = self.trainer.accelerator_backend.backward(result.closure_loss, optimizer, opt_idx)
+            self.backward(result, optimizer)

         # hook
         self.on_after_backward(result.training_step_output, batch_idx, result.loss)

         return result

+    def backward(self, result, optimizer, *args, **kwargs):
+        result.closure_loss = self.trainer.accelerator_backend.backward(
+            result.closure_loss,
+            optimizer,
+            *args,
+            **kwargs
+        )
+
     def update_train_loop_lr_schedulers(self, monitor_metrics=None):
         num_accumulated_batches_reached = (self.trainer.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
         num_training_batches_reached = (self.trainer.batch_idx + 1) == self.trainer.num_training_batches
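Taken together, the changes wire up a single backward path: LightningModule.backward delegates to TrainLoop.backward, which delegates to the accelerator, which either routes through the precision backend or calls loss.backward() directly. A toy, runnable sketch of that call chain using stub classes (not the real Lightning objects) is:

```python
import torch


class StubPrecisionBackend:
    """Stands in for the apex / native-amp backend behind precision_connector."""

    def backward(self, closure_loss, optimizer, *args, **kwargs):
        # a real backend would scale the loss before backpropagating
        closure_loss.backward(*args, **kwargs)
        return closure_loss


class StubAccelerator:
    """Mirrors base_accelerator.backward: route by precision, no opt_idx needed."""

    def __init__(self, precision=32):
        self.precision = precision
        self.precision_backend = StubPrecisionBackend()

    def backward(self, closure_loss, optimizer, *args, **kwargs):
        if self.precision == 16:
            return self.precision_backend.backward(closure_loss, optimizer, *args, **kwargs)
        closure_loss.backward(*args, **kwargs)
        return closure_loss


class StubTrainLoop:
    """Mirrors the TrainLoop.backward added in training_loop.py."""

    def __init__(self, accelerator):
        self.accelerator_backend = accelerator

    def backward(self, closure_loss, optimizer, *args, **kwargs):
        return self.accelerator_backend.backward(closure_loss, optimizer, *args, **kwargs)


# LightningModule.backward(loss, optimizer) now simply forwards to the train loop:
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
StubTrainLoop(StubAccelerator()).backward(loss, optimizer)
optimizer.step()
```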
