Gpu idx #2796

Merged: 19 commits merged on Aug 2, 2020
pytorch_lightning/accelerator_backends/__init__.py (1 addition, 0 deletions)
@@ -3,3 +3,4 @@
from pytorch_lightning.accelerator_backends.dp_backend import DataParallelBackend
from pytorch_lightning.accelerator_backends.ddp_spawn_backend import DDPSpawnBackend
from pytorch_lightning.accelerator_backends.cpu_backend import CPUBackend
+from pytorch_lightning.accelerator_backends.ddp_backend_temp import DDPBackendTemp
pytorch_lightning/accelerator_backends/ddp_backend_temp.py (new file, 213 additions)
@@ -0,0 +1,213 @@
import os
import torch
import subprocess
import sys
from time import sleep
import numpy as np
from os.path import abspath
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning import _logger as log
from typing import Optional

try:
    from hydra.utils import to_absolute_path, get_original_cwd
    from hydra.core.hydra_config import HydraConfig
except ImportError:
    HYDRA_AVAILABLE = False
else:
    HYDRA_AVAILABLE = True

try:
    from apex import amp
except ImportError:
    APEX_AVAILABLE = False


class DDPBackendTemp(object):
Review comment (Member): What is the Temp for?

Review comment (Contributor Author): this is a temporary internal refactor


    def __init__(self, trainer):
        self.trainer = trainer

    def spawn_ddp_children(self, model):
        port = os.environ['MASTER_PORT']

        master_address = '127.0.0.1' if 'MASTER_ADDR' not in os.environ else os.environ['MASTER_ADDR']
        os.environ['MASTER_PORT'] = f'{port}'
        os.environ['MASTER_ADDR'] = f'{master_address}'

        # allow the user to pass the node rank
        node_rank = '0'
        if 'NODE_RANK' in os.environ:
            node_rank = os.environ['NODE_RANK']
        if 'GROUP_RANK' in os.environ:
            node_rank = os.environ['GROUP_RANK']

        os.environ['NODE_RANK'] = node_rank
        os.environ['LOCAL_RANK'] = '0'

        # when user is using hydra find the absolute path
        path_lib = abspath if not HYDRA_AVAILABLE else to_absolute_path

        # pull out the commands used to run the script and resolve the abs file path
        command = sys.argv
        try:
            full_path = path_lib(command[0])
        except Exception as e:
            full_path = abspath(command[0])

        command[0] = full_path
        # use the same python interpreter and actually running
        command = [sys.executable] + command

        # since this script sets the visible devices we replace the gpus flag with a number
        num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()

        if '--gpus' in command:
            gpu_flag_idx = command.index('--gpus')
            command[gpu_flag_idx + 1] = f'{num_gpus}'

        os.environ['WORLD_SIZE'] = f'{num_gpus * self.trainer.num_nodes}'

        self.trainer.interactive_ddp_procs = []
        for local_rank in range(1, self.trainer.num_processes):
            env_copy = os.environ.copy()
            env_copy['LOCAL_RANK'] = f'{local_rank}'

            # start process
            # if hydra is available and initialized, make sure to set the cwd correctly
            cwd: Optional[str] = None
            if HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    cwd = get_original_cwd()
            proc = subprocess.Popen(command, env=env_copy, cwd=cwd)
            self.trainer.interactive_ddp_procs.append(proc)

            # starting all processes at once can cause issues
            # with dataloaders delay between 1-10 seconds
            delay = np.random.uniform(1, 5, 1)[0]
            sleep(delay)

        local_rank = 0
        results = self.ddp_train(local_rank, mp_queue=None, model=model, is_master=True)
        del os.environ['WORLD_SIZE']

        return results

    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:
            is_master:
            proc_offset:

        Returns:

        """
        # offset the process id if requested
        process_idx = process_idx + proc_offset

        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()

        # determine which process we are and world size
        if self.trainer.use_ddp:
            self.trainer.local_rank = process_idx
            self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx
            self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes

        elif self.trainer.use_ddp2:
            self.trainer.local_rank = self.trainer.node_rank
            self.trainer.global_rank = self.trainer.node_rank
            self.trainer.world_size = self.trainer.num_nodes

        # set warning rank
        rank_zero_only.rank = self.trainer.global_rank

        # set up server using proc 0's ip address
        # try to init for 20 times at max in case ports are taken
        # where to store ip_table
        model.trainer = self
        model.init_ddp_connection(
            self.trainer.global_rank,
            self.trainer.world_size,
            self.trainer.is_slurm_managing_tasks
        )

        # call setup after the ddp process has connected
        self.trainer.call_setup_hook(model)

        # on world_size=0 let everyone know training is starting
        if self.trainer.is_global_zero:
            log.info('-' * 100)
            log.info(f'distributed_backend={self.trainer.distributed_backend}')
            log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes')
            log.info('-' * 100)

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
        self.trainer.optimizers = optimizers
        self.trainer.lr_schedulers = lr_schedulers
        self.trainer.optimizer_frequencies = optimizer_frequencies

        # MODEL
        # copy model to each gpu
        if self.trainer.on_gpu:
            gpu_idx = process_idx

            # when using ddp, the master process (proc 0) continues running as the main one
            # this means that the local rank will always be 0
            # (even if cuda visible devices has other visible gpus)
            # this means that the master process needs to pull the 0th visible index as the device number
            if is_master:
                available_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
                gpu_idx = int(available_gpus[self.trainer.local_rank])

            self.trainer.root_gpu = gpu_idx
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)

        # set model properties before going into wrapper
        self.trainer.copy_trainer_model_properties(model)

        # AMP
        # run through amp wrapper before going to distributed DP
        # TODO: remove with dropping NVIDIA AMP support
        if self.trainer.use_amp and not NATIVE_AMP_AVALAIBLE:
            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
            self.trainer.optimizers = optimizers
            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

        # DDP2 uses all GPUs on the machine
        if self.trainer.distributed_backend == 'ddp' or self.trainer.distributed_backend == 'ddp_spawn':
            device_ids = [self.trainer.root_gpu]
        elif self.trainer.use_ddp2:
            device_ids = self.trainer.data_parallel_device_ids
        else:  # includes ddp_cpu
            device_ids = None

        # allow user to configure ddp
        model = model.configure_ddp(model, device_ids)

        # continue training routine
        results = self.trainer.run_pretrain_routine(model)

        # get original model
        model = self.trainer.get_model()

        # persist info in ddp_spawn
        self.trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

        # clean up memory
        torch.cuda.empty_cache()

        if self.trainer.global_rank == 0 and self.trainer.distributed_backend not in ['ddp_spawn', 'ddp_cpu']:
            return results
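Reviewer note: the `is_master` branch above is the namesake of this PR. Child processes can bind to their process index directly, but the master process is the originally launched script, so its local rank stays 0 and the real device id has to be read out of `CUDA_VISIBLE_DEVICES`. Below is a minimal standalone sketch of that resolution logic; the `resolve_root_gpu` helper, its signature, and the example device list are illustrative, not code from the diff.

```python
import os


def resolve_root_gpu(process_idx: int, local_rank: int, is_master: bool) -> int:
    # non-master workers bind to their process index directly
    if not is_master:
        return process_idx
    # the master keeps local rank 0, so look up the actual device id behind
    # that rank in CUDA_VISIBLE_DEVICES (the '0' fallback is an assumption)
    visible_gpus = os.environ.get('CUDA_VISIBLE_DEVICES', '0').split(',')
    return int(visible_gpus[local_rank])


# e.g. training on GPUs 2 and 3: the master (local rank 0) should end up on
# physical GPU 2, not GPU 0, while a child with process index 1 uses index 1
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3'
assert resolve_root_gpu(process_idx=0, local_rank=0, is_master=True) == 2
assert resolve_root_gpu(process_idx=1, local_rank=1, is_master=False) == 1
```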
pytorch_lightning/accelerator_backends/ddp_spawn_backend.py (1 addition, 11 deletions)
@@ -60,23 +60,18 @@ def teardown(self, model):
        self.trainer.model = model
        return results

-    def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
+    def ddp_train(self, process_idx, mp_queue, model):
        """
        Entry point for ddp

        Args:
            process_idx:
            mp_queue: multiprocessing queue
            model:
-            is_master:
-            proc_offset:

        Returns:

        """
-        # offset the process id if requested
-        process_idx = process_idx + proc_offset
-
        # show progressbar only on progress_rank 0
        if (self.trainer.node_rank != 0 or process_idx != 0) and self.trainer.progress_bar_callback is not None:
            self.trainer.progress_bar_callback.disable()
@@ -126,11 +121,6 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0):
        # copy model to each gpu
        if self.trainer.on_gpu:
            gpu_idx = process_idx
-            if is_master:
-                # source of truth is cuda for gpu idx
-                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
-                gpu_idx = int(gpus[self.trainer.local_rank])
-
            self.trainer.root_gpu = gpu_idx
            torch.cuda.set_device(self.trainer.root_gpu)
            model.cuda(self.trainer.root_gpu)
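Reviewer note: with the remapping now living in the new DDP backend, the spawn backend's hunk goes the other way and drops the `is_master` special case entirely, presumably because `torch.multiprocessing.spawn` workers never reuse the originally launched process and can use their process index as the device index directly. A rough sketch of that launch pattern in isolation (the `worker` function and its arguments are illustrative, not code from this PR, and it assumes at least one visible GPU):

```python
import torch
import torch.multiprocessing as mp


def worker(process_idx: int, num_gpus: int):
    # each spawned worker simply binds to the device matching its index;
    # no master-process special case is needed in this launch mode
    torch.cuda.set_device(process_idx)
    print(f"worker {process_idx}/{num_gpus} running on cuda:{process_idx}")


if __name__ == '__main__':
    num_gpus = torch.cuda.device_count()
    mp.spawn(worker, args=(num_gpus,), nprocs=num_gpus)
```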