fixes test issues on ddp #1017

Merged · 17 commits · Mar 3, 2020
22 changes: 22 additions & 0 deletions docs/source/weights_loading.rst
@@ -5,6 +5,19 @@ Lightning can automate saving and loading checkpoints.

Checkpoint saving
-----------------
A Lightning checkpoint has everything needed to restore a training session, including:

- 16-bit scaling factor (apex)
- Current epoch
- Global step
- Model state_dict
- State of all optimizers
- State of all learning rate schedulers
- State of all callbacks
- The hyperparameters used for that model, if passed in as hparams (argparse.Namespace)
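
Because a Lightning checkpoint is an ordinary PyTorch file, you can inspect what was
saved by loading it directly. This is only an illustrative sketch (the exact key names
can differ between Lightning versions), not part of the checkpointing API:

.. code-block:: python

    import torch

    # a checkpoint is a plain dict serialized with torch.save
    checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')

    # typically contains entries such as the epoch, global step, model
    # state_dict, optimizer states and lr scheduler states listed above
    print(checkpoint.keys())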

Automatic saving
^^^^^^^^^^^^^^^^

Checkpointing is enabled by default and checkpoints are saved to the current working directory.
To change the checkpoint path, pass in:
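
The actual snippet is collapsed out of this diff view; as a rough sketch (the
``default_save_path`` argument name is inferred from the Trainer attributes touched in
the code changes below, so treat it as an assumption rather than the exact docs text):

.. code-block:: python

    # save checkpoints somewhere other than os.getcwd()
    trainer = Trainer(default_save_path='/your/path/to/save/checkpoints')
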
@@ -59,6 +72,15 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the LightningModule
def __init__(self, hparams, ...):
self.hparams = hparams

Manual saving
^^^^^^^^^^^^^

To save your own checkpoint, call:

.. code-block:: python

model.save_checkpoint(PATH)
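
For example, at the end of a run (a minimal sketch; ``MyLightningModule`` and the
surrounding setup are placeholders, and the call simply mirrors the one-liner above):

.. code-block:: python

    model = MyLightningModule(hparams)
    trainer = Trainer()
    trainer.fit(model)

    # write the final weights to an explicit location
    model.save_checkpoint('my_checkpoint.ckpt')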

Checkpoint Loading
------------------

7 changes: 7 additions & 0 deletions pytorch_lightning/trainer/callback_config.py
@@ -12,6 +12,9 @@ class TrainerCallbackConfigMixin(ABC):
# the proper values/initialisation should be done in child class
default_save_path: str
logger: Union[LightningLoggerBase, bool]
weights_save_path: str
ckpt_path: str
checkpoint_callback: ModelCheckpoint

@property
@abstractmethod
@@ -29,6 +32,7 @@ def configure_checkpoint_callback(self):
User provided weights_save_path
Otherwise use os.getcwd()
"""
ckpt_path = self.default_save_path
if self.checkpoint_callback is True:
# init a default one
if self.logger is not None:
@@ -44,12 +48,15 @@
else:
ckpt_path = os.path.join(self.default_save_path, "checkpoints")

self.ckpt_path = ckpt_path
self.checkpoint_callback = ModelCheckpoint(
filepath=ckpt_path
)
elif self.checkpoint_callback is False:
self.checkpoint_callback = None

self.ckpt_path = ckpt_path

if self.checkpoint_callback:
# set the path for the callbacks
self.checkpoint_callback.save_function = self.save_checkpoint
30 changes: 30 additions & 0 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -145,6 +145,7 @@ class TrainerDDPMixin(ABC):
use_amp: bool
amp_level: str
use_tpu: bool
default_save_path: str

@property
@abstractmethod
@@ -340,6 +341,35 @@ def ddp_train(self, gpu_idx, model):
# continue training routine
self.run_pretrain_routine(model)

# when ddp ends, we save the model
self.save_spawn_weights(model)

def save_spawn_weights(self, model):
"""
Dump a temporary checkpoint after ddp ends to get weights out of the process
:param model: the trained LightningModule (not read here; the trainer checkpoints its current state)
:return:
"""
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)

def load_spawn_weights(self, original_model):
"""
Load the temp weights saved in the process
To recover the trained model from the ddp process we load the saved weights
:param original_model: the model instance in this parent process that receives the trained weights
:return:
"""
# load weights saved in ddp
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
loaded_model = original_model.__class__.load_from_checkpoint(path)

# copy loaded weights to old model
original_model.load_state_dict(loaded_model.state_dict())

# remove ddp weights
os.remove(path)

def resolve_root_node_address(self, root_node):
if '[' in root_node:
name = root_node.split('[')[0]
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -496,6 +496,8 @@ def tpu_train(self, tpu_core_idx, model):
log.info(m)
self.run_pretrain_routine(model)

self.save_spawn_weights(model)

def dp_train(self, model):

# CHOOSE OPTIMIZER
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
@@ -349,7 +349,11 @@ def run_evaluation(self, test_mode: bool = False):

# log results of test
if test_mode:
model.print(prog_bar_metrics)
if self.proc_rank == 0:
print('-' * 100)
print('TEST RESULTS')
print(prog_bar_metrics)
print('-' * 100)

# log metrics
self.log_metrics(log_metrics, {})
6 changes: 6 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -960,6 +960,8 @@ def fit(
else:
self.__set_random_port()
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
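# the spawned processes train copies of the model, so the trained weights never
# reach this parent process directly; load_spawn_weights reads the temporary
# checkpoint written by save_spawn_weights at the end of ddp_train and copies
# those weights back into the local model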
self.load_spawn_weights(model)
self.model = model

# 1 gpu or dp option triggers training using DP module
# easier to avoid NCCL issues
@@ -975,6 +977,8 @@
# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
self.load_spawn_weights(model)
self.model = model

# ON CPU
else:
@@ -1192,6 +1196,8 @@ def test(self, model: Optional[LightningModule] = None):
if model is not None:
self.model = model
self.fit(model)
elif self.model is not None and (self.use_ddp or self.use_tpu):
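# with ddp or tpu the test loop also has to run inside spawned processes,
# so route through fit(), which handles the spawning and loads the weights
# back, instead of calling run_evaluation here directly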
self.fit(self.model)
else:
self.run_evaluation(test_mode=True)
