Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove ModelSummary validation from train loop on_trainer_init #6610

Merged
merged 6 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,10 @@ def __init__(
self.predict_loop = PredictLoop(self)

# training state
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
)
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
self.weights_summary = weights_summary
self.shown_warnings = set()

Expand Down Expand Up @@ -351,7 +355,6 @@ def __init__(
max_steps,
min_steps,
num_sanity_val_steps,
weights_summary,
)
self.evaluation_loop.on_trainer_init()

Expand Down Expand Up @@ -544,10 +547,7 @@ def _pre_training_routine(self):

# print model summary
if self.is_global_zero and self.weights_summary is not None and not self.testing:
if self.weights_summary in ModelSummary.MODES:
ref_model.summarize(mode=self.weights_summary)
else:
raise MisconfigurationException("weights_summary can be None, " + ", ".join(ModelSummary.MODES))
ref_model.summarize(mode=self.weights_summary)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@edenafek @awaelchli @justusschock The model summary could be another (small) component we could split out from the core LightningModule and offer as a standalone utility. The constructor for the summary object could take:

  • an nn.Module
  • precision
  • device
  • mode

Connecting it back into the trainer here would be easy. It's almost there now with the ModelSummary object.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Related idea: #4541

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw it's already a standalone utility. It can be used like this:

summary = LayerSummary(model, mode=...)
print(summary)

As you say, it could also support nn.Module.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By "standalone" I mean no dependence on the LightningModule. Right now the precision and example inputs are looked up via attributes of the model inside of the model summary, but these could be passed in explicitly to the model summary constructor. It's not a huge deal, but it's a small way for people to see how Lightning builds reusable components.


# restore training and model before hpc is called
self.checkpoint_connector.restore_weights()
Expand Down
23 changes: 8 additions & 15 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from typing import Optional

import numpy as np
import torch

from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import ParallelPlugin
Expand All @@ -36,7 +36,7 @@

class TrainLoop:

def __init__(self, trainer, multiple_trainloader_mode):
def __init__(self, trainer, multiple_trainloader_mode: str):
self.trainer = trainer
self.early_stopping_accumulator = None
self.checkpoint_accumulator = None
Expand All @@ -53,13 +53,12 @@ def __init__(self, trainer, multiple_trainloader_mode):

def on_trainer_init(
self,
max_epochs,
min_epochs,
max_steps,
min_steps,
num_sanity_val_steps,
weights_summary,
):
max_epochs: Optional[int],
min_epochs: Optional[int],
max_steps: Optional[int],
min_steps: Optional[int],
num_sanity_val_steps: int,
) -> None:
self.trainer.global_step = 0
self.trainer.current_epoch = 0
self.trainer.should_stop = False
Expand All @@ -82,12 +81,6 @@ def on_trainer_init(
else:
self.trainer.num_sanity_val_steps = num_sanity_val_steps

self.trainer.weights_summary = weights_summary
if weights_summary is not None and weights_summary not in ModelSummary.MODES:
raise MisconfigurationException(
f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, got {weights_summary}"
)

@property
def num_optimizers(self):
num_optimizers = len(self.get_optimizers_iterable())
Expand Down