handle keyboard interrupt for ddp .test() #1019

Merged: 26 commits, Mar 3, 2020

Commits
25b3ef1  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
f933d77  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
4044c88  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
049efcf  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
b86ae66  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
c223797  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
cba8fc7  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
5909950  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
d88cc89  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
ddd7d89  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
30857ac  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
e2405db  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
7e6c40f  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
dc92539  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
f3282af  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
b6d63ba  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
8a71158  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
f2ad3c8  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
cd76299  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
6aa7783  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
3f902b4  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
b6cd7fc  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
a0c6cba  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
4f80f16  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
86f0529  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
2b74bd4  updated checkpoint docs  (williamFalcon, Mar 3, 2020)
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -350,8 +350,9 @@ def save_spawn_weights(self, model):
         :param model:
         :return:
         """
-        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
-        self.save_checkpoint(path)
+        if self.proc_rank == 0:
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            self.save_checkpoint(path)
 
     def load_spawn_weights(self, original_model):
         """
@@ -370,6 +371,8 @@ def load_spawn_weights(self, original_model):
         # remove ddp weights
         os.remove(path)
 
+        return loaded_model
+
     def resolve_root_node_address(self, root_node):
         if '[' in root_node:
             name = root_node.split('[')[0]
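For context, a minimal sketch of the pattern the two hunks above rely on: a rank-zero guard around the temporary spawn checkpoint, plus load_spawn_weights returning the restored model. This is not the library implementation; the SpawnWeightsMixin name and the torch.save/torch.load stand-ins are assumptions for illustration.

import os

import torch


class SpawnWeightsMixin:
    """Sketch of the temporary spawn-weights pattern, not the library code."""

    def __init__(self, proc_rank: int, default_save_path: str):
        self.proc_rank = proc_rank
        self.default_save_path = default_save_path

    def save_spawn_weights(self, model):
        # only rank 0 writes the temporary checkpoint, so spawned DDP
        # processes never race on the same file
        if self.proc_rank == 0:
            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
            torch.save({'state_dict': model.state_dict()}, path)  # stand-in for save_checkpoint

    def load_spawn_weights(self, original_model):
        # copy the spawned-process weights back into the parent-process model,
        # remove the temporary file, and return the loaded model
        # (returning the model is what the diff above adds)
        path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
        checkpoint = torch.load(path, map_location='cpu')
        original_model.load_state_dict(checkpoint['state_dict'])
        os.remove(path)
        return original_model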
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -337,6 +337,7 @@
 from abc import ABC, abstractmethod
 import logging as log
 import os
+import signal
 
 import torch
 
@@ -494,6 +495,8 @@ def tpu_train(self, tpu_core_idx, model):
         m = f'INIT TPU local core: {self.tpu_local_core_rank}, ' \
             f'global rank: {self.tpu_global_core_rank}'
         log.info(m)
+
+        # continue training routine
         self.run_pretrain_routine(model)
 
         self.save_spawn_weights(model)
40 changes: 23 additions & 17 deletions pytorch_lightning/trainer/trainer.py
@@ -959,7 +959,14 @@ def fit(
                 self.ddp_train(task, model)
             else:
                 self.__set_random_port()
+
+                # track for predict
+                self.model = model
+
+                # train
                 mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
+
+                # load weights if not interrupted
                 self.load_spawn_weights(model)
                 self.model = model
 
@@ -976,7 +983,14 @@
 
             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'
+
+            # track for predict
+            self.model = model
+
+            # train
             xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)
+
+            # load weights if not interrupted
             self.load_spawn_weights(model)
             self.model = model
 
@@ -1192,12 +1206,19 @@ def test(self, model: Optional[LightningModule] = None):
             trainer = Trainer()
             trainer.test(model)
         """
+
         self.testing = True
         if model is not None:
             self.model = model
             self.fit(model)
-        elif self.model is not None and (self.use_ddp or self.use_tpu):
-            self.fit(self.model)
+        elif self.use_ddp or self.use_tpu:
+            # attempt to load weights from a spawn
+            path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
+            test_model = self.model
+            if os.path.exists(path):
+                test_model = self.load_spawn_weights(self.model)
+
+            self.fit(test_model)
         else:
             self.run_evaluation(test_mode=True)
 
@@ -1217,21 +1238,6 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
         return self.dataloader
 
 
-class _PatchDataLoader(object):
-    r'''
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
-
-    Args:
-        dataloader: Dataloader object to return when called.
-    '''
-    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
-        self.dataloader = dataloader
-
-    def __call__(self) -> Union[List[DataLoader], DataLoader]:
-        return self.dataloader
-
-
 def _set_dataloader(model, dataloader, attribute):
     r'''
     Check dataloaders passed to .fit() method if they are pytorch DataLoader
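A hedged usage sketch of the recovery path this change enables: MyLightningModel is a hypothetical module and the Trainer arguments are only illustrative of the DDP spawn path in this era of the API; only Trainer.fit(), Trainer.test(), and the temporary '__temp_weight_ddp_end.ckpt' file come from the diff above.

import pytorch_lightning as pl

from my_project import MyLightningModel  # hypothetical LightningModule

model = MyLightningModel()
trainer = pl.Trainer(gpus=2, distributed_backend='ddp')

try:
    # under DDP, fit() spawns worker processes; rank 0 writes
    # '__temp_weight_ddp_end.ckpt' via save_spawn_weights
    trainer.fit(model)
except KeyboardInterrupt:
    # training was interrupted; the spawn checkpoint may still be on disk
    pass

# with no model argument, test() now looks for the spawn checkpoint and,
# if it exists, loads those weights before running the test loop
trainer.test()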
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_io.py
@@ -6,6 +6,7 @@
 from abc import ABC
 from subprocess import call
 from typing import Union
+from copy import deepcopy
 
 import torch
 import torch.distributed as dist
@@ -233,7 +234,9 @@ def dump_checkpoint(self):
 
         # add the hparams and state_dict from the model
         model = self.get_model()
+
         checkpoint['state_dict'] = model.state_dict()
+
         if hasattr(model, "hparams"):
             checkpoint['hparams'] = vars(model.hparams)
         else: