ref: run_pretrain_routine -> setup_training #3294

Merged: 2 commits, Aug 31, 2020
2 changes: 1 addition & 1 deletion docs/source/trainer.rst
@@ -9,7 +9,7 @@ Trainer
     :members: fit, test
     :noindex:
     :exclude-members:
-        run_pretrain_routine,
+        setup_training,
         _abc_impl,
         set_random_port,
         _Trainer__set_root_gpu,
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -40,7 +40,7 @@ def setup(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)
         return results

     def training_step(self, args):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -151,7 +151,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         model = model.configure_ddp(model, device_ids)

         # continue training routine
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)

         # get original model
         model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -235,7 +235,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
         model = model.configure_ddp(model, device_ids)

         # continue training routine
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)

         # get original model
         model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -162,7 +162,7 @@ def ddp_train(self, process_idx, mp_queue, model):
         model = model.configure_ddp(model, device_ids)

         # continue training routine
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)

         # get original model
         model = self.trainer.get_model()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/dp_backend.py
@@ -96,7 +96,7 @@ def __init_nvidia_apex(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)
         return results

     def teardown(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/gpu_backend.py
@@ -51,7 +51,7 @@ def setup(self, model):

     def train(self):
         model = self.trainer.model
-        results = self.trainer.run_pretrain_routine(model)
+        results = self.trainer.setup_training(model)
         return results

     def training_step(self, args):
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/horovod_backend.py
@@ -105,7 +105,7 @@ def train(self):
             # Synchronization will be performed explicitly following backward()
             stack.enter_context(optimizer.skip_synchronize())

-        result = self.trainer.run_pretrain_routine(self.trainer.model)
+        result = self.trainer.setup_training(self.trainer.model)

         # Make sure all workers have finished training before returning to the user
         hvd.join()
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/tpu_backend.py
@@ -115,7 +115,7 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
         self.__setup_tpu_training(model, trainer)

         # Run the pretrain routine
-        results = trainer.run_pretrain_routine(model)
+        results = trainer.setup_training(model)

         # save weights at the end of training
         self.__save_end_of_training_weights(model, trainer)
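All eight accelerator backends above make the same one-line change: their train() entry point now delegates to trainer.setup_training(model) instead of trainer.run_pretrain_routine(model). The sketch below shows that delegation pattern in isolation; StubTrainer and StubBackend are hypothetical stand-ins, and only the setup_training(model) call mirrors the renamed Trainer method.

    # Hypothetical sketch of the accelerator delegation pattern (not Lightning code).

    class StubTrainer:
        def setup_training(self, model):
            # in Lightning this runs setup, the pretrain hooks,
            # the sanity check, and then the core training loop
            return {"trained": model}

    class StubBackend:
        """Each accelerator holds a trainer and forwards train() to it."""

        def __init__(self, trainer, model):
            self.trainer = trainer
            self.model = model

        def train(self):
            # same shape as cpu_backend/gpu_backend/dp_backend above
            results = self.trainer.setup_training(self.model)
            return results

    print(StubBackend(StubTrainer(), model="my_model").train())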
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -204,10 +204,6 @@ def num_gpus(self) -> int:
     def copy_trainer_model_properties(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def run_pretrain_routine(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def init_optimizers(self, *args) -> Tuple[List, List, List]:
         """Warning: this is just empty shell for code implemented in other class."""
4 changes: 0 additions & 4 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -86,10 +86,6 @@ class TrainerDPMixin(ABC):
     def call_setup_hook(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def run_pretrain_routine(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def init_optimizers(self, *args) -> Tuple[List, List, List]:
         """Warning: this is just empty shell for code implemented in other class."""
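Both removals above follow Lightning's mixin convention: the concrete Trainer is assembled from several mixin classes, and each mixin declares @abstractmethod "empty shells" for methods that a sibling class actually implements. With run_pretrain_routine gone, its placeholder declarations can be deleted. A minimal sketch of the convention, using hypothetical class names:

    from abc import ABC, abstractmethod

    class LoopMixin(ABC):
        # "empty shell": the real implementation lives on a sibling class
        @abstractmethod
        def setup_training(self, *args):
            """Warning: this is just empty shell for code implemented in other class."""

        def train(self):
            return self.setup_training("model")

    class SetupMixin:
        def setup_training(self, *args):
            return f"setup_training called with {args}"

    # composing the mixins satisfies LoopMixin's abstract declaration
    class Trainer(SetupMixin, LoopMixin):
        pass

    print(Trainer().train())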
37 changes: 22 additions & 15 deletions pytorch_lightning/trainer/trainer.py
@@ -1118,12 +1118,15 @@ def can_prepare_data(self):
         else:
             return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data

-    def run_pretrain_routine(self, model: LightningModule):
+    def setup_training(self, model: LightningModule):
         """Sanity check a few things before starting actual training.

         Args:
             model: The model to run sanity test on.
         """
+        # --------------------------
+        # Setup??
+        # --------------------------
         ref_model = model
         if self.data_parallel:
             ref_model = model.module

@@ -1151,7 +1154,7 @@ def run_pretrain_routine(self, model: LightningModule):
         # wait for all models to restore weights
         if self.on_tpu and XLA_AVAILABLE:
             # wait for all processes to catch up
-            torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
+            torch_xla.core.xla_model.rendezvous("pl.Trainer.setup_training")

         elif self.use_horovod:
             # wait for all processes to catch up

@@ -1160,6 +1163,9 @@
             # register auto-resubmit when on SLURM
             self.register_slurm_signal_handlers()

+        # --------------------------
+        # Pre-train
+        # --------------------------
         # on pretrain routine start
         self.on_pretrain_routine_start(ref_model)
         if self.is_function_implemented('on_pretrain_routine_start'):

@@ -1179,6 +1185,14 @@ def run_pretrain_routine(self, model: LightningModule):
         # restore training and model before hpc is called
         self.restore_weights(model)

+        # on pretrain routine end
+        self.on_pretrain_routine_end(ref_model)
+        if self.is_function_implemented('on_pretrain_routine_end'):
+            ref_model.on_pretrain_routine_end()
+
+        # --------------------------
+        # if test
+        # --------------------------
         # when testing requested only run test and return
         if self.testing:
             # only load test dataloader for testing

@@ -1197,22 +1211,15 @@

             return eval_loop_results

+        # --------------------------
+        # sanity
+        # --------------------------
         # run a few val batches before training starts
         self._run_sanity_check(ref_model, model)

-        # clear cache before training
-        if self.on_gpu and self.root_gpu is not None:
-            # use context because of:
-            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
-            with torch.cuda.device(f'cuda:{self.root_gpu}'):
-                torch.cuda.empty_cache()
-
-        # on pretrain routine end
-        self.on_pretrain_routine_end(ref_model)
-        if self.is_function_implemented('on_pretrain_routine_end'):
-            ref_model.on_pretrain_routine_end()
-
-        # CORE TRAINING LOOP
+        # --------------------------
+        # TRAIN
+        # --------------------------
         self.train()

     def _run_sanity_check(self, ref_model, model):
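With the banner comments in place, setup_training reads as a sequence of labeled phases. The condensed sketch below traces only that control flow: the method and hook names come from the diff, _run_test() is a hypothetical placeholder for the elided test branch, and all guards and distributed details are omitted.

    def setup_training(self, model):
        """Condensed, hypothetical sketch of the phases above."""
        # Setup: resolve the reference model; distributed setups also
        # synchronize here (TPU rendezvous, Horovod join, SLURM handlers)
        ref_model = model.module if self.data_parallel else model

        # Pre-train: run the pretrain hooks around weight restoration
        self.on_pretrain_routine_start(ref_model)
        self.restore_weights(model)
        self.on_pretrain_routine_end(ref_model)

        # If test: short-circuit into evaluation and return its results
        if self.testing:
            return self._run_test()  # placeholder for the elided test branch

        # Sanity: run a few val batches before training starts
        self._run_sanity_check(ref_model, model)

        # Train: hand off to the core loop; the GPU cache clearing that
        # used to sit here now happens inside train() (see below)
        self.train()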
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/training_loop.py
@@ -213,6 +213,7 @@ class TrainerTrainLoopMixin(ABC):
     max_epochs: int
     min_epochs: int
     on_gpu: bool
+    root_gpu: ...
     use_ddp: bool
     use_dp: bool
     use_ddp2: bool

@@ -330,14 +331,13 @@ def has_arg(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""

     def train(self):
-        # add signal handlers for process kills
-        # def _signal_kill_handler(*args):
-        #     return TrainerTrainLoopMixin.run_training_teardown(self)
-        #
-        # orig_signal_handlers = {}
-        # for sig_name in SIGNAL_TERMINATE:
-        #     orig_signal_handlers[sig_name] = signal.signal(getattr(signal, sig_name),
-        #                                                    _signal_kill_handler)
+        # TODO: shrink
+        # clear cache before training
+        if self.on_gpu and self.root_gpu is not None:
+            # use context because of:
+            # https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
+            with torch.cuda.device(f'cuda:{self.root_gpu}'):
+                torch.cuda.empty_cache()

         # get model
         model = self.get_model()
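The cache-clearing block that moved into train() is a known PyTorch workaround: calling torch.cuda.empty_cache() without an explicit device context can initialize a CUDA context (and allocate memory) on cuda:0 even when training targets another GPU, per the forum thread linked in the comment. A standalone sketch of the pattern:

    import torch

    def clear_gpu_cache(root_gpu):
        """Free cached CUDA memory on a specific device before training.

        Wrapping empty_cache() in torch.cuda.device(...) pins the call
        to root_gpu instead of letting it touch cuda:0 by default.
        """
        if torch.cuda.is_available() and root_gpu is not None:
            with torch.cuda.device(f"cuda:{root_gpu}"):
                torch.cuda.empty_cache()

    clear_gpu_cache(root_gpu=0)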