Skip to content

Commit

Permalink
model_ref -> lightning_module
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Sep 10, 2021
1 parent a4d49e4 commit cdf5edd
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions pytorch_lightning/loops/optimizer/optimizer_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _optimizer_step(
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
"""
model_ref = self.trainer.lightning_module
lightning_module = self.trainer.lightning_module

is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE
Expand All @@ -257,7 +257,7 @@ def _optimizer_step(
self.optim_progress.optimizer.step.increment_ready()

# model hook
model_ref.optimizer_step(
lightning_module.optimizer_step(
self.trainer.current_epoch,
batch_idx,
optimizer,
Expand Down Expand Up @@ -303,16 +303,16 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos
A ``ClosureResult`` containing the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module
lightning_module = self.trainer.lightning_module

with self.trainer.profiler.profile("model_forward"):

step_kwargs = _build_training_step_kwargs(
model_ref, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
)

# manually capture logged metrics
model_ref._current_fx_name = "training_step"
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
Expand All @@ -321,9 +321,9 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Clos

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(model_ref, training_step_output)
_check_training_step_output(lightning_module, training_step_output)

self._hiddens = _extract_hiddens(training_step_output, model_ref.truncated_bptt_steps)
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

Expand Down

0 comments on commit cdf5edd

Please sign in to comment.