Refactor setup_training and remove test_mode #5388

Merged (14 commits) on Jan 13, 2021
5 changes: 5 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -52,6 +52,10 @@ def __init__(self,
def setup(self, model):
pass

def train(self):
self.trainer.setup_trainer(self.trainer.model)
return self.train_or_test()

def teardown(self):
# Ensure if necessary all processes are finished
self.barrier()
@@ -66,6 +70,7 @@ def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
else:
self.trainer.train_loop.setup_training()
results = self.trainer.train()
return results
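
Read together, the two hunks above are the core of the refactor: `train()` becomes a base-class method that asks the trainer to wire itself to the model via `setup_trainer`, and `train_or_test()` performs the loop-level setup only on the training path. A minimal sketch of the resulting flow; the body of `Trainer.setup_trainer` is not part of this diff, so everything about it here is assumed:

```python
# Sketch of the consolidated entry point after this PR. Only the calls shown
# in the hunks above are taken as given; Trainer.setup_trainer is assumed.

class Accelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def train(self):
        # single shared entry point instead of per-accelerator copies
        self.trainer.setup_trainer(self.trainer.model)
        return self.train_or_test()

    def train_or_test(self):
        if self.trainer.testing:
            # test path: no training-loop setup required
            return self.trainer.run_test()
        # training path: loop-level setup now happens exactly once, here
        self.trainer.train_loop.setup_training()
        return self.trainer.train()
```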

12 changes: 1 addition & 11 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch

@@ -52,16 +52,6 @@ def setup(self, model):

self.trainer.model = model

def train(self):
model = self.trainer.model

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
return results

def _step(self, model_step: Callable, args):
if self.trainer.amp_backend == AMPType.NATIVE:
with torch.cuda.amp.autocast():
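
With the inherited `train()` above, the override deleted here is dead weight; the GPU and DP accelerators below lose identical copies. A condensed sketch of what remains in a single-device backend, continuing the `Accelerator` stub from the previous example:

```python
class CPUAccelerator(Accelerator):
    def setup(self, model):
        # backend-specific preparation; it ends by attaching the model,
        # exactly as the hunk above shows
        self.trainer.model = model

    # No train() override any more: the inherited Accelerator.train() calls
    # trainer.setup_trainer(model) and then dispatches via train_or_test(),
    # which is what the deleted method used to do by hand.
```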
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -186,9 +186,6 @@ def ddp_train(self, process_idx, mp_queue, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -200,8 +197,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
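
In the distributed backends the swap is slightly bigger: the explicit `copy_trainer_model_properties(model)` call and `train_loop.setup_training(model)` both disappear, replaced by one `trainer.setup_trainer(model)` call after `configure_ddp`. The same two-for-one change repeats in every `ddp_*` accelerator below, so this sketch covers them all. The body of `setup_trainer` is not shown anywhere in this diff; the version here is a guess assembled from the calls it replaces and from `restore_weights()` losing its model argument further down:

```python
# Hypothetical sketch only: names and ordering inside setup_trainer are
# assumptions inferred from the calls this PR removes, not copied from it.

class Trainer:
    def setup_trainer(self, model):
        # mirror trainer state onto the model before the DDP wrapper runs,
        # which every ddp_train() previously did explicitly
        self.model_connector.copy_trainer_model_properties(model)
        # attach the (possibly wrapped) model and restore any checkpoint;
        # assumed to absorb the model-dependent half of the old
        # train_loop.setup_training(model)
        self.model = model
        self.checkpoint_connector.restore_weights()
```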
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -285,9 +285,6 @@ def ddp_train(self, process_idx, model):
# allow for lr schedulers as well
self.setup_optimizers(model)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -299,9 +296,8 @@ def ddp_train(self, process_idx, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
@@ -146,9 +146,6 @@ def ddp_train(self, process_idx, mp_queue, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -160,8 +157,7 @@ def ddp_train(self, process_idx, mp_queue, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -177,9 +177,6 @@ def ddp_train(self, process_idx, model):

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -191,8 +188,7 @@ def ddp_train(self, process_idx, model):
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
6 changes: 1 addition & 5 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -161,9 +161,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0

self.ddp_plugin.on_after_setup_optimizers(self.trainer)

# set model properties before going into wrapper
self.trainer.model_connector.copy_trainer_model_properties(model)

# 16-bit
model = self.trainer.precision_connector.connect(model)

@@ -175,8 +172,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# allow user to configure ddp
model = self.configure_ddp(model, device_ids)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -103,16 +103,6 @@ def __init_nvidia_apex(self, model):

return model

def train(self):
model = self.trainer.model
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()

return results

def teardown(self):
# replace the original fwd function
self.trainer.model.forward = self.model_autocast_original_forward
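
The DP backend drops the same `train()` copy; the surviving context shows `teardown()` restoring the model's original `forward`, which was swapped out for an autocast-wrapped version during setup. The patching side is not part of this hunk, so the snippet below is only a generic illustration of that save/patch/restore pattern:

```python
import torch

def patch_forward_with_autocast(model):
    # keep a handle so teardown can restore it, mirroring
    # model_autocast_original_forward in the hunk above
    original_forward = model.forward

    def autocast_forward(*args, **kwargs):
        with torch.cuda.amp.autocast():
            return original_forward(*args, **kwargs)

    model.forward = autocast_forward
    # teardown later runs: model.forward = original_forward
    return original_forward
```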
10 changes: 0 additions & 10 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -58,16 +58,6 @@ def setup(self, model):

self.trainer.model = model

def train(self):
model = self.trainer.model

# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
return results

def _step(self, model_step: Callable, args):
args[0] = self.to_device(args[0])

7 changes: 3 additions & 4 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import ExitStack
from typing import Any, Optional, Union, Callable
from typing import Any, Callable, Optional, Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HOROVOD_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_only

if HOROVOD_AVAILABLE:
@@ -106,8 +106,7 @@ def train(self):
# Synchronization will be performed explicitly following backward()
stack.enter_context(optimizer.skip_synchronize())

# set up training routine
self.trainer.train_loop.setup_training(self.trainer.model)
self.trainer.setup_trainer(self.trainer.model)

# train or test
results = self.train_or_test()
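
In the Horovod backend the replacement lands inside the `ExitStack` block that registers `skip_synchronize()` for each optimizer, so gradient synchronization is deferred to the explicit call after `backward()`. A condensed sketch of that pattern; it assumes the optimizers are already wrapped in `hvd.DistributedOptimizer`, and the free-function shape is illustrative, not the accelerator's actual API:

```python
from contextlib import ExitStack

def run_horovod_training(trainer, accelerator):
    with ExitStack() as stack:
        for optimizer in trainer.optimizers:
            # skip_synchronize() is a context manager on Horovod's
            # DistributedOptimizer; entering it via the stack keeps it
            # active for the whole run inside this with-block
            stack.enter_context(optimizer.skip_synchronize())

        trainer.setup_trainer(trainer.model)
        return accelerator.train_or_test()
```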
5 changes: 2 additions & 3 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -26,11 +26,11 @@
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
rank_zero_warn,
TPU_AVAILABLE,
)
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -134,8 +134,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
# setup TPU training
self.__setup_tpu_training(model, trainer)

# set up training routine
self.trainer.train_loop.setup_training(model)
self.trainer.setup_trainer(model)

# train or test
results = self.train_or_test()
@@ -13,8 +13,8 @@
# limitations under the License.

import os
from pathlib import Path
import re
from pathlib import Path
from typing import Optional, Union

import torch
@@ -44,7 +44,7 @@ def __init__(self, trainer):
# used to validate checkpointing logic
self.has_trained = False

def restore_weights(self, model: LightningModule) -> None:
def restore_weights(self) -> None:
"""
Attempt to restore a checkpoint (e.g. weights) in this priority:
1. from HPC weights
@@ -64,7 +64,7 @@ def restore_weights(self) -> None:
rank_zero_info(f'restored hpc model from: {checkpoint_path}')

# 2. Attempt to restore states from `resume_from_checkpoint` file
elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing:
elif self.trainer.resume_from_checkpoint is not None:
self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)

# wait for all to catch up
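
`restore_weights()` drops its unused `model` parameter, and the `and not self.trainer.testing` guard disappears, so a `resume_from_checkpoint` file is now restored for test runs as well. A condensed sketch of the resulting priority order; the HPC lookup, `hpc_load`, and the barrier call sit outside the shown hunk, so those names are assumptions:

```python
from pytorch_lightning.utilities import rank_zero_info

def restore_weights(self) -> None:
    hpc_ckpt = self._lookup_hpc_checkpoint()      # assumed helper name
    if hpc_ckpt is not None:
        # 1. HPC weights take precedence
        self.hpc_load(hpc_ckpt, self.trainer.on_gpu)
        rank_zero_info(f'restored hpc model from: {hpc_ckpt}')
    elif self.trainer.resume_from_checkpoint is not None:
        # 2. resume_from_checkpoint; the old `and not self.trainer.testing`
        #    guard is gone, so test runs restore from it too
        self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu)
    # 3. wait for all processes to catch up before continuing
    self.trainer.accelerator_backend.barrier('restore_weights')
```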
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import os
from copy import deepcopy
from pprint import pprint
from typing import Iterable, Union

@@ -211,9 +211,9 @@ def add_progress_bar_metrics(self, metrics):

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result, test_mode):
def track_metrics_deprecated(self, deprecated_eval_results, using_eval_result):
self._track_callback_metrics(deprecated_eval_results, using_eval_result)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results, test_mode)
self.__process_eval_epoch_end_results_and_log_legacy(deprecated_eval_results)

def evaluation_epoch_end(self, testing):
# reset dataloader idx
@@ -242,7 +242,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):
def get_evaluate_epoch_results(self):
if not self.trainer.running_sanity_check:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -252,7 +252,7 @@ def get_evaluate_epoch_results(self):
self.prepare_eval_loop_results()

# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} TEST RESULTS')
@@ -333,7 +333,7 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)

def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mode):
def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
if self.trainer.running_sanity_check:
return

@@ -353,7 +353,7 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
callback_metrics = result.callback_metrics

# in testing we don't need the callback metrics
if test_mode:
if self.trainer.testing:
callback_metrics = {}
else:
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
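
The logger connector applies the same simplification throughout: every `test_mode` parameter is removed and the connector reads `self.trainer.testing` directly, so callers no longer thread a boolean that duplicates trainer state. A minimal illustration of the pattern, condensed from the hunks above (the `eval_loop_results` handling is simplified and the surrounding methods are omitted):

```python
class LoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        self.eval_loop_results = []

    def get_evaluate_epoch_results(self):
        # previously: get_evaluate_epoch_results(self, test_mode)
        if self.trainer.testing and self.trainer.is_global_zero and self.trainer.verbose_test:
            print('-' * 80)
            for result_idx, results in enumerate(self.eval_loop_results):
                print(f'DATALOADER:{result_idx} TEST RESULTS')
                print(results)
        return self.eval_loop_results
```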