handle keyboard interrupt for ddp .test() (Lightning-AI#1019)
* updated checkpoint docs
williamFalcon authored and tullie committed Apr 3, 2020
1 parent 8f35e4e commit e688c6a
Showing 4 changed files with 34 additions and 19 deletions.
7 changes: 5 additions & 2 deletions pytorch_lightning/trainer/distrib_data_parallel.py
@@ -350,8 +350,9 @@ def save_spawn_weights(self, model):
:param model:
:return:
"""
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)
if self.proc_rank == 0:
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
self.save_checkpoint(path)

def load_spawn_weights(self, original_model):
"""
@@ -370,6 +371,8 @@ def load_spawn_weights(self, original_model):
# remove ddp weights
os.remove(path)

return loaded_model

def resolve_root_node_address(self, root_node):
if '[' in root_node:
name = root_node.split('[')[0]
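The hunk above makes two small behavioral changes: save_spawn_weights now writes the temporary checkpoint only on rank 0, and load_spawn_weights returns the reloaded model instead of nothing. A minimal sketch of the same pattern outside Lightning (the function names and save_dir argument are illustrative, not from the commit; only the temp filename matches the diff):

    import os
    import torch

    TMP_CKPT = '__temp_weight_ddp_end.ckpt'  # same temp filename as in the diff

    def save_spawn_weights_sketch(model, proc_rank, save_dir):
        # only rank 0 writes the temporary checkpoint, so the N spawned
        # processes do not race on the same file
        if proc_rank == 0:
            torch.save(model.state_dict(), os.path.join(save_dir, TMP_CKPT))

    def load_spawn_weights_sketch(model, save_dir):
        # the parent process restores the weights written by rank 0, cleans up
        # the temp file, and hands the restored model back to the caller
        path = os.path.join(save_dir, TMP_CKPT)
        model.load_state_dict(torch.load(path))
        os.remove(path)
        return model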
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -337,6 +337,7 @@
from abc import ABC, abstractmethod
import logging as log
import os
import signal

import torch

@@ -494,6 +495,8 @@ def tpu_train(self, tpu_core_idx, model):
m = f'INIT TPU local core: {self.tpu_local_core_rank}, ' \
f'global rank: {self.tpu_global_core_rank}'
log.info(m)

# continue training routine
self.run_pretrain_routine(model)

self.save_spawn_weights(model)
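The only substantive additions to distrib_parts.py are the signal import and a clarifying comment in tpu_train. The import suggests that spawned workers can install their own interrupt handling; a common pattern for that (an assumption for illustration, the handler itself is not part of the visible diff) looks like:

    import signal
    import sys

    def install_sigint_handler():
        # let a spawned worker exit quietly on Ctrl-C instead of printing
        # one traceback per process; illustrative only
        def handler(signum, frame):
            sys.exit(0)
        signal.signal(signal.SIGINT, handler)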
40 changes: 23 additions & 17 deletions pytorch_lightning/trainer/trainer.py
@@ -959,7 +959,14 @@ def fit(
self.ddp_train(task, model)
else:
self.__set_random_port()

# track for predict
self.model = model

# train
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))

# load weights if not interrupted
self.load_spawn_weights(model)
self.model = model

@@ -976,7 +983,14 @@ def fit(

# COLAB_GPU is an env var available by default in Colab environments.
start_method = 'fork' if os.getenv('COLAB_GPU') else 'spawn'

# track for predict
self.model = model

# train
xmp.spawn(self.tpu_train, args=(model,), nprocs=self.num_tpu_cores, start_method=start_method)

# load weights if not interrupted
self.load_spawn_weights(model)
self.model = model

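Both spawn branches in fit() above now follow the same three-step shape: remember the model on the trainer so .test() can find it later, spawn the training processes, then reload the weights that rank 0 dumped to the temporary checkpoint. Sketched in isolation (assuming a train_fn suited to torch.multiprocessing; names other than mp.spawn are illustrative):

    import torch.multiprocessing as mp

    def fit_spawn_sketch(trainer, model, train_fn, num_gpus):
        # track for predict: .test() can later reuse trainer.model
        trainer.model = model

        # train: each spawned process runs train_fn(process_idx, model)
        mp.spawn(train_fn, nprocs=num_gpus, args=(model,))

        # load weights if not interrupted: pull back the checkpoint that
        # rank 0 wrote at the end of the spawned run
        trainer.load_spawn_weights(model)
        trainer.model = model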
@@ -1192,12 +1206,19 @@ def test(self, model: Optional[LightningModule] = None):
trainer = Trainer()
trainer.test(model)
"""

self.testing = True
if model is not None:
self.model = model
self.fit(model)
elif self.model is not None and (self.use_ddp or self.use_tpu):
self.fit(self.model)
elif self.use_ddp or self.use_tpu:
# attempt to load weights from a spawn
path = os.path.join(self.default_save_path, '__temp_weight_ddp_end.ckpt')
test_model = self.model
if os.path.exists(path):
test_model = self.load_spawn_weights(self.model)

self.fit(test_model)
else:
self.run_evaluation(test_mode=True)

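With this hunk, .test() under DDP or TPU no longer assumes a live process group: if a model is passed it simply re-enters fit() in test mode, and otherwise it looks for the temporary spawn checkpoint and restores those weights before testing. A hedged usage sketch (MyModel and the 2-GPU setup are illustrative, not from the commit):

    model = MyModel()
    trainer = Trainer(gpus=2, distributed_backend='ddp')
    trainer.fit(model)

    # either re-enter fit() in test mode with an explicit model ...
    trainer.test(model)

    # ... or let the trainer reuse trainer.model, restoring the
    # __temp_weight_ddp_end.ckpt weights if a spawned run left them behind
    trainer.test()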
@@ -1217,21 +1238,6 @@ def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


class _PatchDataLoader(object):
r'''
Callable object for patching dataloaders passed into trainer.fit().
Use this class to override model.*_dataloader() and be pickle-compatible.
Args:
dataloader: Dataloader object to return when called.
'''
def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

def __call__(self) -> Union[List[DataLoader], DataLoader]:
return self.dataloader


def _set_dataloader(model, dataloader, attribute):
r'''
Check dataloaders passed to .fit() method if they are pytorch DataLoader
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/training_io.py
@@ -6,6 +6,7 @@
from abc import ABC
from subprocess import call
from typing import Union
from copy import deepcopy

import torch
import torch.distributed as dist
@@ -233,7 +234,9 @@ def dump_checkpoint(self):

# add the hparams and state_dict from the model
model = self.get_model()

checkpoint['state_dict'] = model.state_dict()

if hasattr(model, "hparams"):
checkpoint['hparams'] = vars(model.hparams)
else:
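The training_io.py hunk only adds the deepcopy import and whitespace around the part of dump_checkpoint that packs the model into the checkpoint dict. For reference, that packing step boils down to the following (stripped of the surrounding Trainer state; names follow the diff):

    def dump_checkpoint_sketch(model):
        checkpoint = {}

        # add the hparams and state_dict from the model
        checkpoint['state_dict'] = model.state_dict()

        if hasattr(model, 'hparams'):
            checkpoint['hparams'] = vars(model.hparams)

        return checkpoint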
