ref: organize args 2/n #3448

Merged (3 commits) on Sep 10, 2020
5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/accelerator_connector.py
@@ -20,7 +20,8 @@ def on_trainer_init(
num_nodes,
log_gpu_memory,
sync_batchnorm,
benchmark
benchmark,
replace_sampler_ddp
):
# benchmarking
self.trainer.benchmark = benchmark
@@ -84,6 +85,8 @@ def on_trainer_init(

self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

self.trainer.replace_sampler_ddp = replace_sampler_ddp

def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
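For context (not part of the diff): `replace_sampler_ddp` keeps its public `Trainer` signature; this change only moves where the flag is stored. Illustrative usage:

```python
from pytorch_lightning import Trainer

# Setting the flag to False tells Lightning to leave user-provided samplers
# untouched in distributed training instead of injecting a DistributedSampler.
trainer = Trainer(replace_sampler_ddp=False)
```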
7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_connector.py
@@ -1,3 +1,4 @@
import os
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, ProgressBarBase, ProgressBar
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_utils import is_overridden
@@ -15,7 +16,13 @@ def on_trainer_init(
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path
):
# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir

# init callbacks
self.trainer.callbacks = callbacks or []

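A minimal standalone sketch of the fallback logic that moved into `CallbackConnector.on_trainer_init` (the helper name `resolve_save_paths` is hypothetical, for illustration only):

```python
import os


def resolve_save_paths(default_root_dir=None, weights_save_path=None):
    # default_root_dir falls back to the current working directory,
    # and weights_save_path falls back to default_root_dir.
    root_dir = default_root_dir or os.getcwd()
    save_path = weights_save_path or root_dir
    return root_dir, save_path


# With no arguments, both paths resolve to os.getcwd().
print(resolve_save_paths())
```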
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,11 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch):
self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
self.trainer._is_data_prepared = False

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(self._with_is_last(train_dataloader)),
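Illustrative usage (not part of the diff) of the two data flags now routed through `DataConnector.on_trainer_init`:

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    check_val_every_n_epoch=5,            # run the validation loop every 5th epoch
    reload_dataloaders_every_epoch=True,  # rebuild dataloaders at the start of each epoch
)
```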
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/debugging_connector.py
@@ -35,6 +35,7 @@ def on_init_start(
overfit_batches,
fast_dev_run
):

self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
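Illustrative usage (not part of the diff): `fast_dev_run` is the existing public flag handled here; enabling it caps every loop at a single batch for a quick end-to-end sanity check.

```python
from pytorch_lightning import Trainer

# Runs 1 training batch and 1 validation batch, then exits.
trainer = Trainer(fast_dev_run=True)
```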
24 changes: 24 additions & 0 deletions pytorch_lightning/trainer/logger_connector.py
@@ -13,10 +13,12 @@
# limitations under the License.
import torch
from pytorch_lightning.core import memory
from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.core.step_result import EvalResult, Result
from pprint import pprint
from typing import Iterable


class LoggerConnector:
@@ -27,6 +29,28 @@ def __init__(self, trainer):
self.logged_metrics = {}
self.progress_bar_metrics = {}

def on_trainer_init(self, logger, log_save_interval, row_log_interval):
# logging
self.configure_logger(logger)
self.trainer.log_save_interval = log_save_interval
self.trainer.row_log_interval = row_log_interval

def configure_logger(self, logger):
if logger is True:
# default logger
self.trainer.logger = TensorBoardLogger(
save_dir=self.trainer.default_root_dir,
version=self.trainer.slurm_job_id,
name='lightning_logs'
)
elif logger is False:
self.trainer.logger = None
else:
if isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
else:
self.trainer.logger = logger

def log_metrics(self, metrics, grad_norm_dic, step=None):
"""Logs the metric dict passed in.
If `step` parameter is None and `step` key is present in metrics,
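A standalone sketch mirroring the branches of the relocated `configure_logger` (the function name `resolve_logger` is hypothetical, for illustration only):

```python
from typing import Iterable

from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger


def resolve_logger(logger, default_root_dir, version):
    if logger is True:
        # default logger, rooted at the trainer's default_root_dir
        return TensorBoardLogger(save_dir=default_root_dir, version=version, name='lightning_logs')
    if logger is False:
        return None
    if isinstance(logger, Iterable):
        # several loggers -> wrap them in a LoggerCollection
        return LoggerCollection(logger)
    return logger
```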
16 changes: 0 additions & 16 deletions pytorch_lightning/trainer/logging.py
@@ -39,22 +39,6 @@ class TrainerLoggingMixin(ABC):
num_gpus: int
logged_metrics: ...

def configure_logger(self, logger):
if logger is True:
# default logger
self.logger = TensorBoardLogger(
save_dir=self.default_root_dir,
version=self.slurm_job_id,
name='lightning_logs'
)
elif logger is False:
self.logger = None
else:
if isinstance(logger, Iterable):
self.logger = LoggerCollection(logger)
else:
self.logger = logger

def metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():
50 changes: 21 additions & 29 deletions pytorch_lightning/trainer/trainer.py
@@ -46,6 +46,7 @@
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
@@ -177,6 +178,7 @@ def __init__(
self.precision_connector = PrecisionConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)

self.tuner = Tuner(self)
self.accelerator_backend = None
@@ -203,6 +205,7 @@ def __init__(
self.tested_ckpt_path = None

# training state
self.weights_summary = weights_summary
self.model = None
self.datamodule = None
self.testing = False
@@ -217,26 +220,30 @@ def __init__(
self.running_sanity_check = False
self._state = TrainerState.INITIALIZING

self._default_root_dir = default_root_dir or os.getcwd()
self._weights_save_path = weights_save_path or self._default_root_dir

# init callbacks
self.callback_connector.on_trainer_init(
callbacks,
early_stop_callback,
checkpoint_callback,
progress_bar_refresh_rate,
process_position
process_position,
default_root_dir,
weights_save_path,
)

self.on_init_start()
# init data flags
self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch)

self.gradient_clip_val = gradient_clip_val
self.check_val_every_n_epoch = check_val_every_n_epoch
# hook
self.on_init_start()

if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.track_grad_norm = float(track_grad_norm)
# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps
)

# init accelerator related flags
self.accelerator_connector.on_trainer_init(
@@ -248,25 +255,16 @@ def __init__(
num_nodes,
log_gpu_memory,
sync_batchnorm,
benchmark
benchmark,
replace_sampler_ddp
)

# -------------------
# CONTINUE
# -------------------
self.weights_summary = weights_summary

# init train loop related flags
self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)

self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size
self._is_data_prepared = False
self.replace_sampler_ddp = replace_sampler_ddp

self.truncated_bptt_steps = truncated_bptt_steps
self.resume_from_checkpoint = resume_from_checkpoint
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()
@@ -276,14 +274,8 @@ def __init__(
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()

# accumulated grads
self.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

# logging
self.configure_logger(logger)
self.log_save_interval = log_save_interval
self.row_log_interval = row_log_interval
# init logger flags
self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)

# init debugging flags
self.debugging_connector.on_init_start(
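A self-contained sketch of the delegation pattern this refactor extends (simplified, hypothetical classes; not the actual `pytorch_lightning` code): `Trainer.__init__` builds small connector objects and hands each group of constructor arguments to the matching `on_trainer_init` / `on_init_start` hook.

```python
class TrainingTricksConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, gradient_clip_val, track_grad_norm):
        # each connector writes its group of flags back onto the trainer
        self.trainer.gradient_clip_val = gradient_clip_val
        self.trainer.track_grad_norm = float(track_grad_norm)


class MiniTrainer:
    def __init__(self, gradient_clip_val=0.0, track_grad_norm=-1):
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.training_tricks_connector.on_trainer_init(gradient_clip_val, track_grad_norm)


trainer = MiniTrainer(gradient_clip_val=0.5)
assert trainer.gradient_clip_val == 0.5
```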
45 changes: 45 additions & 0 deletions pytorch_lightning/trainer/training_trick_connector.py
@@ -0,0 +1,45 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.callbacks import GradientAccumulationScheduler


class TrainingTricksConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps):
# gradient clipping
self.trainer.gradient_clip_val = gradient_clip_val

# gradient norm tracking
if not isinstance(track_grad_norm, (int, float)) and track_grad_norm != 'inf':
raise MisconfigurationException("track_grad_norm can be an int, a float or 'inf' (infinity norm).")
self.trainer.track_grad_norm = float(track_grad_norm)

# accumulated grads
self.trainer.accumulate_grad_batches = accumulate_grad_batches
self.configure_accumulated_gradients(accumulate_grad_batches)

self.trainer.truncated_bptt_steps = truncated_bptt_steps

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {0: accumulate_grad_batches}
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")
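Illustrative usage (not part of the diff) of the two forms `configure_accumulated_gradients` accepts:

```python
from pytorch_lightning import Trainer

# int: accumulate 4 batches before every optimizer step, from epoch 0 onwards
trainer = Trainer(accumulate_grad_batches=4)

# dict: an {epoch: factor} schedule; accumulate 1 batch until epoch 5, then 8 batches
trainer = Trainer(accumulate_grad_batches={0: 1, 5: 8})
```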
9 changes: 0 additions & 9 deletions pytorch_lightning/trainer/training_tricks.py
@@ -76,12 +76,3 @@ def detect_nan_tensors(self, loss: Tensor) -> None:
f'Detected nan and/or inf values in `{name}`.'
' Check your forward pass for numerically unstable operations.'
)

def configure_accumulated_gradients(self, accumulate_grad_batches):
if isinstance(accumulate_grad_batches, dict):
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
elif isinstance(accumulate_grad_batches, int):
schedule = {0: accumulate_grad_batches}
self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
else:
raise TypeError("Gradient accumulation supports only int and dict types")