ref: clean up ddp before final fix (#3817)
* ref: clean up ddp before final fix
williamFalcon authored Oct 3, 2020
1 parent 0838c6b commit ed1450a
Showing 2 changed files with 10 additions and 23 deletions.
pytorch_lightning/accelerators/accelerator_connector.py (2 changes: 1 addition & 1 deletion)

@@ -154,7 +154,7 @@ def select_accelerator(self):
             accelerator_backend = accelerators.DDPCPUSpawnBackend(self.trainer, nprocs=self.trainer.num_processes)
 
         elif self.trainer.distributed_backend == "ddp":
-            accelerator_backend = accelerators.DDPBackend(self.trainer, mode='ddp')
+            accelerator_backend = accelerators.DDPBackend(self.trainer)
 
         elif self.trainer.use_dp:
             accelerator_backend = accelerators.DataParallelBackend(self.trainer)
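A minimal sketch of what this hunk changes, using toy stand-ins rather than Lightning's real classes: the connector now constructs the DDP backend from the trainer alone, because the `mode` flag is removed from `DDPBackend` (see the ddp_backend.py diff below).

    # Toy illustration only; ToyTrainer and ToyDDPBackend are hypothetical stand-ins.
    class ToyTrainer:
        def __init__(self, distributed_backend):
            self.distributed_backend = distributed_backend

    class ToyDDPBackend:
        def __init__(self, trainer):  # no `mode` argument any more
            self.trainer = trainer

    def select_accelerator(trainer):
        if trainer.distributed_backend == "ddp":
            # was: ToyDDPBackend(trainer, mode='ddp')
            return ToyDDPBackend(trainer)
        raise ValueError(f"unsupported backend: {trainer.distributed_backend}")

    backend = select_accelerator(ToyTrainer("ddp"))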
pytorch_lightning/accelerators/ddp_backend.py (31 changes: 9 additions & 22 deletions)

@@ -43,30 +43,21 @@

 class DDPBackend(Accelerator):
 
-    def __init__(self, trainer, mode: str = 'ddp'):
+    def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
         self._has_spawned_children = False
-        self.mode = mode
         self.dist = LightningDistributed()
 
     def setup(self, model):
-        if self.mode == 'ddp':
-            self.__ddp_script_mode_setup()
-        elif self.mode == 'slurm_ddp':
-            self.__slurm_setup()
-        elif self.mode == 'torchelastic_ddp':
-            self.__torchelastic_setup()
-
         # first track model
         self.trainer.model = model
 
-    def __slurm_setup(self):
-        self.task_idx = int(os.environ['SLURM_LOCALID'])
-
-    def __torchelastic_setup(self):
-        self.task_idx = int(os.environ['LOCAL_RANK'])
-
-    def __ddp_script_mode_setup(self):
+        # start the other scripts
+        self._call_children_scripts()
+
+    def _call_children_scripts(self):
         assert self.trainer.global_rank == 0
         self._check_can_spawn_children()
         self._has_spawned_children = True
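For context, a self-contained, hedged sketch of the "script mode" launch pattern that _call_children_scripts implements: the rank-0 process re-invokes its own script once per additional process, tagging each child with LOCAL_RANK. The names and details here are illustrative, not the exact method body, which this view truncates.

    import os
    import subprocess
    import sys

    def launch_children(num_processes):
        """Hypothetical stand-in for _call_children_scripts."""
        children = []
        for local_rank in range(1, num_processes):
            env = os.environ.copy()
            env["LOCAL_RANK"] = str(local_rank)  # child reads its rank from the env
            # re-launch the same script as a child process
            children.append(subprocess.Popen([sys.executable] + sys.argv, env=env))
        return children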
@@ -137,12 +128,9 @@ def __ddp_script_mode_setup(self):

     def train(self):
         model = self.trainer.model
-        if self.mode == 'ddp':
-            results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
-            del os.environ['WORLD_SIZE']
-            return results
-        else:
-            self.ddp_train(process_idx=self.task_idx, model=model)
+        results = self.ddp_train(process_idx=self.task_idx, model=model, is_master=True)
+        del os.environ['WORLD_SIZE']
+        return results
 
     def _check_can_spawn_children(self):
         if self._has_spawned_children:
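A hedged aside on the cleanup kept above: `del os.environ['WORLD_SIZE']` raises KeyError if the variable was never set. A defensive variant, not what this commit uses, would be:

    import os

    # remove WORLD_SIZE if present, silently do nothing otherwise
    os.environ.pop("WORLD_SIZE", None)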
@@ -288,5 +276,4 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         # clean up memory
         torch.cuda.empty_cache()
 
-        if self.trainer.global_rank == 0:
-            return results
+        return results
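A toy illustration (not Lightning code) of the behavioral change in this final hunk: previously only global rank 0 returned `results`, so every other rank implicitly returned None; now each rank returns its own results.

    def ddp_train_old(global_rank, results):
        if global_rank == 0:
            return results  # other ranks fall through and return None

    def ddp_train_new(global_rank, results):
        return results  # every rank returns its results

    assert ddp_train_old(1, "metrics") is None
    assert ddp_train_new(1, "metrics") == "metrics"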
