diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 53b4920bd85ef..226e77f3a58c2 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -320,6 +320,10 @@ def __init__( self.predict_loop = PredictLoop(self) # training state + if weights_summary is not None and weights_summary not in ModelSummary.MODES: + raise MisconfigurationException( + f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}" + ) self.weights_summary = weights_summary self.shown_warnings = set() @@ -351,7 +355,6 @@ def __init__( max_steps, min_steps, num_sanity_val_steps, - weights_summary, ) self.evaluation_loop.on_trainer_init() @@ -544,10 +547,7 @@ def _pre_training_routine(self): # print model summary if self.is_global_zero and self.weights_summary is not None and not self.testing: - if self.weights_summary in ModelSummary.MODES: - ref_model.summarize(mode=self.weights_summary) - else: - raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES)) + ref_model.summarize(mode=self.weights_summary) # restore training and model before hpc is called self.checkpoint_connector.restore_weights() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7e737c424ff26..26a178c4fd77b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,12 +14,12 @@ from contextlib import contextmanager, suppress from copy import copy, deepcopy +from typing import Optional import numpy as np import torch from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.step_result import Result from pytorch_lightning.plugins import ParallelPlugin @@ -36,7 +36,7 @@ class TrainLoop: - def __init__(self, trainer, multiple_trainloader_mode): + def __init__(self, trainer, multiple_trainloader_mode: str): self.trainer = trainer self.early_stopping_accumulator = None self.checkpoint_accumulator = None @@ -53,13 +53,12 @@ def __init__(self, trainer, multiple_trainloader_mode): def on_trainer_init( self, - max_epochs, - min_epochs, - max_steps, - min_steps, - num_sanity_val_steps, - weights_summary, - ): + max_epochs: Optional[int], + min_epochs: Optional[int], + max_steps: Optional[int], + min_steps: Optional[int], + num_sanity_val_steps: int, + ) -> None: self.trainer.global_step = 0 self.trainer.current_epoch = 0 self.trainer.should_stop = False @@ -82,12 +81,6 @@ def on_trainer_init( else: self.trainer.num_sanity_val_steps = num_sanity_val_steps - self.trainer.weights_summary = weights_summary - if weights_summary is not None and weights_summary not in ModelSummary.MODES: - raise MisconfigurationException( - f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}" - ) - @property def num_optimizers(self): num_optimizers = len(self.get_optimizers_iterable())