
Commit

Merge a69d3b0 into 3449e2d
awaelchli committed Feb 18, 2021
2 parents 3449e2d + a69d3b0
commit 31706fa
Showing 21 changed files with 140 additions and 131 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012))


- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035))


### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
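The changelog entry above amounts to a one-line migration at call sites. A minimal before/after sketch, assuming a standard `Trainer` setup (the `LitModel` class and names below are illustrative and not part of this diff):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Illustrative module; any LightningModule behaves the same way here."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


trainer = pl.Trainer(max_epochs=1)

# Deprecated: method call, scheduled for removal in v1.4.
# model_ref = trainer.get_model()

# Preferred: the explicit property introduced by this change.
model_ref = trainer.lightning_module  # None until a model is attached via fit or test
```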
6 changes: 3 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -101,11 +101,11 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
return optimizer

def _toggle_model(self):
model_ref = self._trainer.get_model()
model_ref = self._trainer.lightning_module
model_ref.toggle_optimizer(self, self._optimizer_idx)

def _untoggle_model(self):
model_ref = self._trainer.get_model()
model_ref = self._trainer.lightning_module
model_ref.untoggle_optimizer(self)

@contextmanager
@@ -129,7 +129,7 @@ def toggle_model(self, sync_grad: bool = True):
def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
trainer = self._trainer
optimizer = self._optimizer
model = trainer.get_model()
model = trainer.lightning_module

with trainer.profiler.profile(profiler_name):
trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
67 changes: 34 additions & 33 deletions pytorch_lightning/trainer/callback_hook.py
@@ -14,17 +14,18 @@

from abc import ABC
from copy import deepcopy
from typing import Callable, List
from typing import List

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule


class TrainerCallbackHookMixin(ABC):

# this is just a summary on variables used in this abstract class,
# the proper values/initialisation should be done in child class
callbacks: List[Callback] = []
get_model: Callable
lightning_module: LightningModule

def on_before_accelerator_backend_setup(self, model):
"""Called in the beginning of fit and test"""
@@ -39,7 +40,7 @@ def setup(self, model, stage: str):
def teardown(self, stage: str):
"""Called at the end of fit and test"""
for callback in self.callbacks:
callback.teardown(self, self.get_model(), stage)
callback.teardown(self, self.lightning_module, stage)

def on_init_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
@@ -54,72 +55,72 @@ def on_init_end(self):
def on_fit_start(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_start(self, self.get_model())
callback.on_fit_start(self, self.lightning_module)

def on_fit_end(self):
"""Called when the trainer initialization begins, model has not yet been set."""
for callback in self.callbacks:
callback.on_fit_end(self, self.get_model())
callback.on_fit_end(self, self.lightning_module)

def on_sanity_check_start(self):
"""Called when the validation sanity check starts."""
for callback in self.callbacks:
callback.on_sanity_check_start(self, self.get_model())
callback.on_sanity_check_start(self, self.lightning_module)

def on_sanity_check_end(self):
"""Called when the validation sanity check ends."""
for callback in self.callbacks:
callback.on_sanity_check_end(self, self.get_model())
callback.on_sanity_check_end(self, self.lightning_module)

def on_train_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.get_model())
callback.on_train_epoch_start(self, self.lightning_module)

def on_train_epoch_end(self, outputs):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.get_model(), outputs)
callback.on_train_epoch_end(self, self.lightning_module, outputs)

def on_validation_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.get_model())
callback.on_validation_epoch_start(self, self.lightning_module)

def on_validation_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.get_model())
callback.on_validation_epoch_end(self, self.lightning_module)

def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.get_model())
callback.on_test_epoch_start(self, self.lightning_module)

def on_test_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_test_epoch_end(self, self.get_model())
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.get_model())
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.get_model())
callback.on_epoch_end(self, self.lightning_module)

def on_train_start(self):
"""Called when the train begins."""
for callback in self.callbacks:
callback.on_train_start(self, self.get_model())
callback.on_train_start(self, self.lightning_module)

def on_train_end(self):
"""Called when the train ends."""
for callback in self.callbacks:
callback.on_train_end(self, self.get_model())
callback.on_train_end(self, self.lightning_module)

def on_pretrain_routine_start(self, model):
"""Called when the train begins."""
@@ -134,74 +135,74 @@ def on_pretrain_routine_end(self, model):
def on_batch_start(self):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_batch_start(self, self.get_model())
callback.on_batch_start(self, self.lightning_module)

def on_batch_end(self):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_batch_end(self, self.get_model())
callback.on_batch_end(self, self.lightning_module)

def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the training batch begins."""
for callback in self.callbacks:
callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the training batch ends."""
for callback in self.callbacks:
callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the validation batch begins."""
for callback in self.callbacks:
callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the validation batch ends."""
for callback in self.callbacks:
callback.on_validation_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
"""Called when the test batch begins."""
for callback in self.callbacks:
callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)

def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
"""Called when the test batch ends."""
for callback in self.callbacks:
callback.on_test_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)

def on_validation_start(self):
"""Called when the validation loop begins."""
for callback in self.callbacks:
callback.on_validation_start(self, self.get_model())
callback.on_validation_start(self, self.lightning_module)

def on_validation_end(self):
"""Called when the validation loop ends."""
for callback in self.callbacks:
callback.on_validation_end(self, self.get_model())
callback.on_validation_end(self, self.lightning_module)

def on_test_start(self):
"""Called when the test begins."""
for callback in self.callbacks:
callback.on_test_start(self, self.get_model())
callback.on_test_start(self, self.lightning_module)

def on_test_end(self):
"""Called when the test ends."""
for callback in self.callbacks:
callback.on_test_end(self, self.get_model())
callback.on_test_end(self, self.lightning_module)

def on_keyboard_interrupt(self):
"""Called when the training is interrupted by KeyboardInterrupt."""
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.get_model())
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_save_checkpoint(self):
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
callback_class = type(callback)
state = callback.on_save_checkpoint(self, self.get_model())
state = callback.on_save_checkpoint(self, self.lightning_module)
if state:
callback_states[callback_class] = state
return callback_states
@@ -224,11 +225,11 @@ def on_after_backward(self):
Called after loss.backward() and before optimizers do anything.
"""
for callback in self.callbacks:
callback.on_after_backward(self, self.get_model())
callback.on_after_backward(self, self.lightning_module)

def on_before_zero_grad(self, optimizer):
"""
Called after optimizer.step() and before optimizer.zero_grad().
"""
for callback in self.callbacks:
callback.on_before_zero_grad(self, self.get_model(), optimizer)
callback.on_before_zero_grad(self, self.lightning_module, optimizer)
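Every hook wrapper in this file now forwards `self.lightning_module` (instead of the result of `self.get_model()`) as the `pl_module` argument of the corresponding `Callback` method. A short sketch of a callback consuming that argument, assuming the two-argument hook signatures shown above (the callback itself is illustrative and not part of this diff):

```python
from pytorch_lightning.callbacks import Callback


class DeviceLoggingCallback(Callback):
    """Illustrative callback; pl_module is the trainer's LightningModule."""

    def on_train_start(self, trainer, pl_module):
        # pl_module is the same object as trainer.lightning_module
        print(f"Training starts on device: {pl_module.device}")

    def on_train_end(self, trainer, pl_module):
        print(f"Training finished after {trainer.current_epoch + 1} epoch(s).")
```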
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -94,7 +94,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# restore model and datamodule state
self.restore_model_state(model, checkpoint)
@@ -214,7 +214,7 @@ def hpc_save(self, folderpath: str, logger):
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

# give model a chance to do something on hpc_save
model = self.trainer.get_model()
model = self.trainer.lightning_module
checkpoint = self.dump_checkpoint()

model.on_hpc_save(checkpoint)
@@ -307,7 +307,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the hyper_parameters and state_dict from the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# dump the module_arguments and state_dict from the model
checkpoint['state_dict'] = model.state_dict()
@@ -339,7 +339,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

# acquire the model
model = self.trainer.get_model()
model = self.trainer.lightning_module

# restore model and datamodule state
self.restore_model_state(model, checkpoint)
@@ -235,7 +235,7 @@ def info(self):
"""
This function provides necessary parameters to properly configure HookResultStore obj
"""
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
return {
"batch_idx": self.trainer.batch_idx,
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
@@ -252,7 +252,7 @@ def reset_model(self):
"""
This function is used to reset model state at the end of the capture
"""
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
model_ref._results = Result()
model_ref._current_hook_fx_name = None
model_ref._current_fx_name = ''
@@ -263,7 +263,7 @@ def cache_result(self) -> None:
and store the result object
"""
with self.trainer.profiler.profile("cache_result"):
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module

# extract hook results
hook_result = model_ref._results
@@ -82,7 +82,7 @@ def cached_results(self) -> Union[EpochResultStore, None]:

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
metrics_holder.convert(
self.trainer._device_type == DeviceType.TPU,
model_ref.device if model_ref is not None else model_ref,
@@ -103,7 +103,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc

def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
# Todo: required argument `testing` is not used
model = self.trainer.get_model()
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
# track batch_size
@@ -263,7 +263,7 @@ def track_metrics_deprecated(self, deprecated_eval_results):
def evaluation_epoch_end(self, testing):
# Todo: required argument `testing` is not used
# reset dataloader idx
model_ref = self.trainer.get_model()
model_ref = self.trainer.lightning_module
model_ref._current_dataloader_idx = None

# setting `has_batch_loop_finished` to True
@@ -408,7 +408,7 @@ def log_train_epoch_end_metrics(
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)

model = self.trainer.get_model()
model = self.trainer.lightning_module

epoch_callback_metrics = {}

10 changes: 1 addition & 9 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -25,7 +25,7 @@ def __init__(self, trainer):
self.trainer = trainer

def copy_trainer_model_properties(self, model):
ref_model = self._get_reference_model(model)
ref_model = self.trainer.lightning_module or model

automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
self.trainer.train_loop.automatic_optimization = automatic_optimization
@@ -37,11 +37,3 @@ def copy_trainer_model_properties(self, model):
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
m.precision = self.trainer.precision

def get_model(self):
return self._get_reference_model(self.trainer.model)

def _get_reference_model(self, model):
if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module:
return self.trainer.accelerator_backend.lightning_module
return model
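The removed `get_model`/`_get_reference_model` pair resolved the reference module through the accelerator backend when one was attached. The `Trainer.lightning_module` property that replaces it is not shown in this diff; the following is only a plausible sketch, offered as an assumption about how the property could behave while mirroring the deleted helper, not the actual implementation:

```python
class Trainer:
    # ... other Trainer attributes elided ...

    @property
    def lightning_module(self):
        # Assumed behaviour (not taken from this diff): prefer the module tracked
        # by the accelerator backend, fall back to the raw model, exactly as the
        # removed _get_reference_model() did.
        if self.accelerator_backend is not None and self.accelerator_backend.lightning_module is not None:
            return self.accelerator_backend.lightning_module
        return self.model
```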
13 changes: 13 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
@@ -130,3 +131,15 @@ def use_single_gpu(self, val: bool) -> None:
)
if val:
self.accelerator_connector._device_type = DeviceType.GPU


class DeprecatedModelAttributes:

lightning_module = LightningModule

def get_model(self) -> LightningModule:
rank_zero_warn(
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
" and will be removed in v1.4.", DeprecationWarning
)
return self.lightning_module
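The shim above keeps existing call sites working while pointing users at the property. A brief sketch of how the deprecation surfaces at a call site, assuming `trainer` is an already-constructed `Trainer` (the warning-capture scaffolding is illustrative):

```python
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = trainer.get_model()  # still works, but emits a DeprecationWarning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert model is trainer.lightning_module  # both paths return the same object
```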