
ref: move backends back to individual files (1/5) (ddp_cpu) #3712

Merged: 8 commits, Sep 29, 2020
pytorch_lightning/accelerators/accelerator_connector.py: 6 additions & 1 deletion
@@ -163,8 +163,13 @@ def select_accelerator(self):
         elif self.trainer.use_tpu:
             accelerator_backend = accelerators.TPUBackend(self.trainer)

-        else:
+        elif self.trainer.distributed_backend is None:
             accelerator_backend = accelerators.CPUBackend(self.trainer)
+        else:
+            raise MisconfigurationException(
+                f'Trainer(distributed_backend={self.trainer.distributed_backend} '
+                f'is not a supported backend'
+            )

         return accelerator_backend
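
As an aside, not part of the diff: with this change an unrecognized distributed_backend string fails loudly instead of silently falling through to the CPU backend, which is presumably why the dpp_spawn typo in the test file further down is fixed in the same PR. A rough sketch of the new error path, assuming the Trainer.accelerator_connector wiring of this release; the snippet is hypothetical usage, not code from this PR:

# Hypothetical illustration, not part of the diff.
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException

trainer = Trainer(distributed_backend='dpp_spawn')  # note the 'dpp' typo

try:
    # select_accelerator() normally runs inside trainer.fit(); it is called
    # directly here only to show the new failure mode.
    trainer.accelerator_connector.select_accelerator()
except MisconfigurationException as err:
    print(err)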

pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py: 183 additions & 6 deletions
@@ -11,14 +11,191 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
-from pytorch_lightning.accelerators.ddp_spawn_backend import DDPSpawnBackend
+import os
+import re

+import torch
+import torch.distributed as torch_distrib
+import torch.distributed as dist
+import torch.multiprocessing as mp

-class DDPCPUSpawnBackend(DDPSpawnBackend):
+from pytorch_lightning import _logger as log
+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities.cloud_io import atomic_save
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import find_free_network_port

-    def model_to_device(self, model, process_idx, is_master):
-        pass
+try:
+    from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path
+except ImportError:
+    HYDRA_AVAILABLE = False
+else:
+    HYDRA_AVAILABLE = True

-    def get_device_ids(self):
+
+class DDPCPUSpawnBackend(Accelerator):
@ananthsub (Contributor) commented on Sep 29, 2020:
Naming nit: if run with torchelastic, then torchelastic does the spawn, not the backend. Also, this doesn't resolve the bug here: https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/accelerators/accelerator_connector.py#L142-L143

We basically need a DDPCPUBackend which takes torchelastic_ddp as a mode.
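
For illustration only, one possible shape of what this comment suggests: a single CPU DDP backend parameterized by a launch mode, so that under torchelastic_ddp the backend skips mp.spawn and reuses the externally launched worker processes. The DDPCPUBackend class, its mode argument, and the LOCAL_RANK handling below are hypothetical sketches, not code from this PR:

import os

import torch.multiprocessing as mp


class DDPCPUBackend:  # hypothetical, not the class added in this PR
    """Sketch of a CPU DDP backend that takes the launch mode as an argument."""

    def __init__(self, trainer, mode='ddp_cpu_spawn', nprocs=1):
        self.trainer = trainer
        self.mode = mode      # e.g. 'ddp_cpu_spawn' or 'torchelastic_ddp'
        self.nprocs = nprocs
        self.mp_queue = None

    def train(self, model):
        if self.mode == 'torchelastic_ddp':
            # torchelastic already launched one process per worker, so read
            # the rank it assigned instead of spawning child processes here.
            process_idx = int(os.environ.get('LOCAL_RANK', 0))
            return self.ddp_train(process_idx, mp_queue=None, model=model)

        # default: spawn the CPU worker processes ourselves
        smp = mp.get_context('spawn')
        self.mp_queue = smp.SimpleQueue()
        mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model))

    def ddp_train(self, process_idx, mp_queue, model):
        # shared training entry point, as in DDPCPUSpawnBackend.ddp_train in the diff
        ...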


+    def __init__(self, trainer, nprocs):
+        super().__init__(trainer)
+        self.mp_queue = None
+        self.nprocs = nprocs
+
+    def setup(self, model):
+        os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
+
+        # pass in a state q
+        smp = mp.get_context('spawn')
+        self.mp_queue = smp.SimpleQueue()
+
+        self.trainer.model = model
+
+    def train(self):
+        model = self.trainer.model
+
+        # train in children process
+        mp.spawn(self.ddp_train, nprocs=self.nprocs, args=(self.mp_queue, model,))
+
+        # restore main state with best weights
+        best_path = self.mp_queue.get()
+        results = self.mp_queue.get()
+        last_path = self.mp_queue.get()
+
+        # recover the weights of the processes trained in the children
+        self.__recover_child_process_weights(model, best_path, last_path)
+        return results
+
+    def __recover_child_process_weights(self, model, best_path, last_path):
+        # transfer back the best path to the trainer
+        if self.trainer.checkpoint_callback:
+            self.trainer.checkpoint_callback.best_model_path = best_path
+        # todo, pass also best score
+
+        # load last weights
+        if last_path is not None and not self.trainer.testing:
+            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
+            model.load_state_dict(ckpt)
+
+        self.trainer.model = model
+
+    def ddp_train(self, process_idx, mp_queue, model):
+        """
+        Entry point for ddp
+        Args:
+            process_idx:
+            mp_queue: multiprocessing queue
+            model:
+        Returns:
+        """
+        # show progressbar only on progress_rank 0
+        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
+            self.trainer.progress_bar_callback.disable()
+
+        # determine which process we are and world size
+        self.trainer.local_rank = process_idx
+        self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
+        self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes
+
+        # set warning rank
+        rank_zero_only.rank = self.trainer.global_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self.trainer
+        model.init_ddp_connection(
+            self.trainer.global_rank,
+            self.trainer.world_size,
+            self.trainer.is_slurm_managing_tasks
+        )
+
+        # call setup after the ddp process has connected
+        self.trainer.call_setup_hook(model)
+
+        # on world_size=0 let everyone know training is starting
+        if self.trainer.is_global_zero:
+            log.info('-' * 100)
+            log.info(f'distributed_backend={self.trainer.distributed_backend}')
+            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
+            log.info('-' * 100)
+
+        # call sync_bn before .cuda(), configure_apex and configure_ddp
+        if self.trainer.sync_batchnorm:
+            model = model.configure_sync_batchnorm(model)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.setup_optimizers(model)
+
+        # set model properties before going into wrapper
+        self.trainer.model_connector.copy_trainer_model_properties(model)
+
+        # 16-bit
+        model = self.trainer.precision_connector.connect(model)
+
+        # DDP spawn already spawned off each process... no need to do anything
         device_ids = None
-        return device_ids
+
+        # allow user to configure ddp
+        model = model.configure_ddp(model, device_ids)
+
+        # set up training routine
+        self.trainer.train_loop.setup_training(model)
+
+        # train or test
+        results = self.train_or_test()
+
+        # get original model
+        model = self.trainer.get_model()
+
+        # persist info in ddp_spawn
+        self.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+    def training_step(self, args):
+        if self.trainer.amp_backend == AMPType.NATIVE:
+            with torch.cuda.amp.autocast():
+                output = self.trainer.model(*args)
+        else:
+            output = self.trainer.model(*args)
+        return output
+
+    def validation_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def test_step(self, args):
+        output = self.training_step(args)
+        return output
+
+    def barrier(self, name: str = None):
+        torch_distrib.barrier()
+
+    def early_stopping_should_stop(self, pl_module):
+        stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
+        dist.all_reduce(stop, op=dist.reduce_op.SUM)
+        dist.barrier()
+        should_stop = stop == self.trainer.world_size
+        return should_stop
+
+    def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
+        # track the best model path
+        best_model_path = None
+        if self.trainer.checkpoint_callback is not None:
+            best_model_path = self.trainer.checkpoint_callback.best_model_path
+
+        if self.trainer.global_rank == 0 and mp_queue is not None:
+            rank_zero_warn('cleaning up ddp environment...')
+            # todo, pass complete checkpoint as state dictionary
+            mp_queue.put(best_model_path)
+            mp_queue.put(results)
+
+            # save the last weights
+            last_path = None
+            if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
+                last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
+                atomic_save(model.state_dict(), last_path)
+            mp_queue.put(last_path)
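
For context, not part of the diff: at this point in the refactor, this backend backs distributed_backend='ddp_cpu' runs, spawning num_processes CPU workers through the train()/ddp_train() path above. A minimal usage sketch; the toy module, tensor shapes, and hyperparameters are made up for illustration:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModule(pl.LightningModule):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == '__main__':
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    trainer = pl.Trainer(distributed_backend='ddp_cpu', num_processes=2, max_steps=4)
    trainer.fit(TinyModule(), DataLoader(dataset, batch_size=16))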
tests/utilities/test_dtype_device_mixin.py: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     model = TopModule()
     trainer = Trainer(
         default_root_dir=tmpdir,
-        distributed_backend='dpp_spawn',
+        distributed_backend='ddp_spawn',
         gpus=2,
         callbacks=[DeviceAssertCallback()],
         max_steps=1,