Replace .get_model() with explicit .lightning_module #6035

Merged · 10 commits · Feb 18, 2021
Changes from 3 commits
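For a custom Callback, or any code holding a Trainer reference, the change amounts to swapping the old method call for the new explicit attribute. A minimal sketch of the before/after usage (the callback below is illustrative only and not part of this PR):

from pytorch_lightning import Callback


class PrintModelCallback(Callback):
    """Hypothetical callback showing the renamed accessor."""

    def on_train_start(self, trainer, pl_module):
        # The accessor this PR replaces:
        #   model = trainer.get_model()
        # The explicit attribute used instead:
        model = trainer.lightning_module
        print(type(model).__name__)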
6 changes: 3 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -101,11 +101,11 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
return optimizer

def _toggle_model(self):
model_ref = self._trainer.get_model()
model_ref = self._trainer.lightning_module
model_ref.toggle_optimizer(self, self._optimizer_idx)

def _untoggle_model(self):
model_ref = self._trainer.get_model()
model_ref = self._trainer.lightning_module
model_ref.untoggle_optimizer(self)

@contextmanager
@@ -129,7 +129,7 @@ def toggle_model(self, sync_grad: bool = True):
def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
trainer = self._trainer
optimizer = self._optimizer
model = trainer.get_model()
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
67 changes: 34 additions & 33 deletions pytorch_lightning/trainer/callback_hook.py
@@ -14,8 +14,9 @@

from abc import ABC
from copy import deepcopy
from typing import Callable, List
from typing import List

from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback


@@ -24,7 +25,7 @@ class TrainerCallbackHookMixin(ABC):
# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
callbacks: List[Callback] = []
get_model: Callable
lightning_module: LightningModule

def on_before_accelerator_backend_setup(self, model):
"""Called in the beginning of fit and test"""
@@ -39,7 +40,7 @@ def setup(self, model, stage: str):
def teardown(self, stage: str):
"""Called at the end of fit and test"""
for callback in self.callbacks:
callback.teardown(self, self.get_model(), stage)
callback.teardown(self, self.lightning_module, stage)
Contributor: do you want to move this outside of the loop to save on unwrapping per callback?

Contributor (Author): I think it's almost not worth it, because unwrapping doesn't really add any overhead since it's just following references, i.e., model.module in the case of DDP, for example. I could be wrong, or it may change in the future. Do you see it as a big overhead?
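To illustrate the point about unwrapping being cheap: getting the underlying LightningModule out of a distributed wrapper is a plain attribute lookup, not a copy. A rough, self-contained sketch (the helper name below is invented for illustration; the trainer does the equivalent through its connectors and accelerator):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def unwrap_lightning_module(model: nn.Module) -> nn.Module:
    # DistributedDataParallel keeps the wrapped module on its `.module` attribute,
    # so "unwrapping" just follows that reference.
    if isinstance(model, DistributedDataParallel):
        return model.module
    return model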


def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
@@ -54,72 +55,72 @@ def on_init_end(self):
def on_fit_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_start(self, self.get_model())
callback.on_fit_start(self, self.lightning_module)

def on_fit_end(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_end(self, self.get_model())
callback.on_fit_end(self, self.lightning_module)

def on_sanity_check_start(self):
"""Called when the validation sanity check starts."""
for callback in self.callbacks:
callback.on_sanity_check_start(self, self.get_model())
callback.on_sanity_check_start(self, self.lightning_module)

def on_sanity_check_end(self):
"""Called when the validation sanity check ends."""
for callback in self.callbacks:
callback.on_sanity_check_end(self, self.get_model())
callback.on_sanity_check_end(self, self.lightning_module)

def on_train_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.get_model())
callback.on_train_epoch_start(self, self.lightning_module)

def on_train_epoch_end(self, outputs):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.get_model(), outputs)
callback.on_train_epoch_end(self, self.lightning_module, outputs)

def on_validation_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.get_model())
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.get_model())
callback.on_validation_epoch_end(self, self.lightning_module)

def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.get_model())
callback.on_test_epoch_start(self, self.lightning_module)

def on_test_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_test_epoch_end(self, self.get_model())
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.get_model())
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.get_model())
callback.on_epoch_end(self, self.lightning_module)

def on_train_start(self):
"""Called when the train begins."""
for callback in self.callbacks:
callback.on_train_start(self, self.get_model())
callback.on_train_start(self, self.lightning_module)

def on_train_end(self):
"""Called when the train ends."""
for callback in self.callbacks:
callback.on_train_end(self, self.get_model())
callback.on_train_end(self, self.lightning_module)

def on_pretrain_routine_start(self, model):
"""Called when the train begins."""
@@ -134,74 +135,74 @@ def on_pretrain_routine_end(self, model):
def on_batch_start(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_batch_start(self, self.get_model())
callback.on_batch_start(self, self.lightning_module)

def on_batch_end(self):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_batch_end(self, self.get_model())
callback.on_batch_end(self, self.lightning_module)

def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the validation batch begins."""
for callback in self.callbacks:
callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the validation batch ends."""
for callback in self.callbacks:
callback.on_validation_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the test batch begins."""
for callback in self.callbacks:
callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the test batch ends."""
for callback in self.callbacks:
callback.on_test_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_validation_start(self):
"""Called when the validation loop begins."""
for callback in self.callbacks:
callback.on_validation_start(self, self.get_model())
callback.on_validation_start(self, self.lightning_module)

def on_validation_end(self):
"""Called when the validation loop ends."""
for callback in self.callbacks:
callback.on_validation_end(self, self.get_model())
callback.on_validation_end(self, self.lightning_module)

def on_test_start(self):
"""Called when the test begins."""
for callback in self.callbacks:
callback.on_test_start(self, self.get_model())
callback.on_test_start(self, self.lightning_module)

def on_test_end(self):
"""Called when the test ends."""
for callback in self.callbacks:
callback.on_test_end(self, self.get_model())
callback.on_test_end(self, self.lightning_module)

def on_keyboard_interrupt(self):
"""Called when the training is interrupted by KeyboardInterrupt."""
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.get_model())
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_save_checkpoint(self):
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
callback_class = type(callback)
state = callback.on_save_checkpoint(self, self.get_model())
state = callback.on_save_checkpoint(self, self.lightning_module)
if state:
callback_states[callback_class] = state
return callback_states
@@ -224,11 +225,11 @@ def on_after_backward(self):
Called after loss.backward() and before optimizers do anything.
"""
for callback in self.callbacks:
callback.on_after_backward(self, self.get_model())
callback.on_after_backward(self, self.lightning_module)

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
"""
for callback in self.callbacks:
callback.on_before_zero_grad(self, self.get_model(), optimizer)
callback.on_before_zero_grad(self, self.lightning_module, optimizer)
@@ -94,7 +94,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# restore model and datamodule state
self.restore_model_state(model, checkpoint)
@@ -214,7 +214,7 @@ def hpc_save(self, folderpath: str, logger):
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
model = self.trainer.get_model()
model = self.trainer.lightning_module
checkpoint = self.dump_checkpoint()

model.on_hpc_save(checkpoint)
@@ -307,7 +307,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the hyper_parameters and state_dict from the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# dump the module_arguments and state_dict from the model
checkpoint['state_dict'] = model.state_dict()
@@ -339,7 +339,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# restore model and datamodule state
self.restore_model_state(model, checkpoint)
@@ -235,7 +235,7 @@ def info(self):
"""
This function provides necessary parameters to properly configure HookResultStore obj
"""
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
return {
"batch_idx": self.trainer.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
@@ -252,7 +252,7 @@ def reset_model(self):
"""
This function is used to reset model state at the end of the capture
"""
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
model_ref._results = Result()
model_ref._current_hook_fx_name = None
model_ref._current_fx_name = ''
@@ -263,7 +263,7 @@ def cache_result(self) -> None:
and store the result object
"""
with self.trainer.profiler.profile("cache_result"):
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module

# extract hook results
hook_result = model_ref._results
@@ -82,7 +82,7 @@ def cached_results(self) -> Union[EpochResultStore, None]:

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
metrics_holder.convert(
self.trainer._device_type == DeviceType.TPU,
model_ref.device if model_ref is not None else model_ref,
@@ -103,7 +103,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc

def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
# Todo: required argument `testing` is not used
model = self.trainer.get_model()
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
# track batch_size
@@ -263,7 +263,7 @@ def track_metrics_deprecated(self, deprecated_eval_results):
def evaluation_epoch_end(self, testing):
# Todo: required argument `testing` is not used
# reset dataloader idx
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
model_ref._current_dataloader_idx = None

# setting `has_batch_loop_finished` to True
@@ -408,7 +408,7 @@ def log_train_epoch_end_metrics(
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)

model = self.trainer.get_model()
model = self.trainer.lightning_module

epoch_callback_metrics = {}

3 changes: 0 additions & 3 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -38,9 +38,6 @@ def copy_trainer_model_properties(self, model):
m.testing = self.trainer.testing
m.precision = self.trainer.precision

def get_model(self):
return self._get_reference_model(self.trainer.model)

def _get_reference_model(self, model):
rohitgr7 marked this conversation as resolved.
if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module:
return self.trainer.accelerator_backend.lightning_module
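For context on the removal above: with get_model() gone from the connector, the unwrapped module is exposed on the trainer as an explicit attribute. The snippet below is a self-contained toy of that pattern, not code from this PR; the class and method names are invented for illustration only.

class ToyConnector:
    """Stands in for the model connector; resolves the reference model."""

    def __init__(self, trainer):
        self.trainer = trainer

    def _get_reference_model(self, model):
        # Follow the wrapper's `.module` reference if there is one.
        return getattr(model, "module", model)


class ToyTrainer:
    """Stands in for the Trainer; exposes the module as a read-only property."""

    def __init__(self, model):
        self.model = model
        self.model_connector = ToyConnector(self)

    @property
    def lightning_module(self):
        return self.model_connector._get_reference_model(self.model)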