Remove redundant None check from spawn plugins (#10855)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Dec 1, 2021
1 parent 3a8b3fc commit 619ef7a
Showing 2 changed files with 39 additions and 39 deletions.
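
Background for the change: in the spawn plugins the mp_queue is created by the parent process before any worker starts, so the worker-side `self.mp_queue is not None` check removed here could never be False. Below is a minimal sketch of that queue-before-spawn lifecycle, assuming nothing about Lightning's internals; it uses the standard library's multiprocessing, and every name in it is illustrative.

import multiprocessing as mp


def run_worker(rank, queue):
    # The queue handed to every worker was created before spawning,
    # so it can never be None inside the worker: a None check here is dead code.
    if rank == 0:
        queue.put(f"results from rank {rank}")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue = ctx.SimpleQueue()  # constructed unconditionally, before any worker exists
    workers = [ctx.Process(target=run_worker, args=(rank, queue)) for rank in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()
    print(queue.get())  # the parent reads back what rank 0 put

Because the parent owns the queue's lifetime, worker code can treat it as always present, which is exactly what the early-return refactor in the diff below relies on.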
39 changes: 20 additions & 19 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -269,31 +269,32 @@ def determine_ddp_device_ids(self):
         return [self.root_device.index]

     def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+        rank_zero_warn("cleaning up ddp environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()

-        if self.global_rank == 0 and self.mp_queue is not None:
-            rank_zero_warn("cleaning up ddp environment...")
-
-            # save the last weights
-            last_path = None
-            if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.checkpoint_io.save_checkpoint(state_dict, last_path)
-
-            # todo, pass complete checkpoint as state dictionary
-            self.mp_queue.put(best_model_path)
-            self.mp_queue.put(last_path)
-            self.mp_queue.put(results)
-            # adds the `callback_metrics` to the queue
-            # TODO: Remove the if in v1.7
-            if is_overridden("add_to_queue", self.lightning_module):
-                self.lightning_module.add_to_queue(self.mp_queue)
-            else:
-                self.add_to_queue(trainer, self.mp_queue)
+        if self.global_rank != 0:
+            return
+
+        # save the last weights
+        last_path = None
+        if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
+            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
+            self.checkpoint_io.save_checkpoint(state_dict, last_path)
+
+        # todo, pass complete checkpoint as state dictionary
+        self.mp_queue.put(best_model_path)
+        self.mp_queue.put(last_path)
+        self.mp_queue.put(results)
+        # adds the `callback_metrics` to the queue
+        # TODO: Remove the if in v1.7
+        if is_overridden("add_to_queue", self.lightning_module):
+            self.lightning_module.add_to_queue(self.mp_queue)
+        else:
+            self.add_to_queue(trainer, self.mp_queue)

     def __recover_child_process_weights(self, best_path, last_path):
         # transfer back the best path to the trainer
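
The three put calls establish an ordering contract: the parent process reads the objects back with blocking get calls in the same sequence, presumably before handing the paths to `__recover_child_process_weights` (whose signature appears in the trailing context above). A minimal sketch of that parent-side read, with a hypothetical function name:

from multiprocessing import SimpleQueue
from typing import Any, Optional, Tuple


def recover_spawn_results(mp_queue: SimpleQueue) -> Tuple[Optional[str], Optional[str], Any]:
    # Must mirror the worker's put order exactly: best path, last path, results.
    best_model_path = mp_queue.get()
    last_path = mp_queue.get()
    results = mp_queue.get()
    return best_model_path, last_path, results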
39 changes: 19 additions & 20 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -209,32 +209,31 @@ def barrier(self, name: Optional[str] = None) -> None:
         rendezvous(name)

     def __transfer_distrib_spawn_state_on_fit_end(self, trainer: "pl.Trainer", results: Any) -> None:
+        rank_zero_warn("cleaning up tpu spawn environment...")
         checkpoint_callback = trainer.checkpoint_callback
         best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None

         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = self.lightning_module.state_dict()

-        if self.mp_queue is not None:
-            rank_zero_warn("cleaning up tpu spawn environment...")
-
-            # save the last weights
-            last_path = None
-            if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
-                last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-                self.checkpoint_io.save_checkpoint(state_dict, last_path)
-
-            if self.local_rank == 0:
-                # todo, pass complete checkpoint as state dictionary
-                self.mp_queue.put(best_model_path)
-                self.mp_queue.put(last_path)
-                self.mp_queue.put(results)
-                # adds the `callback_metrics` to the queue
-                # TODO: Remove the if in v1.7
-                if is_overridden("add_to_queue", self.lightning_module):
-                    self.lightning_module.add_to_queue(self.mp_queue)
-                else:
-                    self.add_to_queue(trainer, self.mp_queue)
+        # save the last weights
+        last_path = None
+        if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
+            last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
+            self.checkpoint_io.save_checkpoint(state_dict, last_path)
+
+        # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
+        if self.local_rank == 0:
+            # todo, pass complete checkpoint as state dictionary
+            self.mp_queue.put(best_model_path)
+            self.mp_queue.put(last_path)
+            self.mp_queue.put(results)
+            # adds the `callback_metrics` to the queue
+            # TODO: Remove the if in v1.7
+            if is_overridden("add_to_queue", self.lightning_module):
+                self.lightning_module.add_to_queue(self.mp_queue)
+            else:
+                self.add_to_queue(trainer, self.mp_queue)

     def broadcast(self, obj: object, src: int = 0) -> object:
         if not self.is_distributed:
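
The comment added in the hunk above records the one subtlety that differs from ddp_spawn: on a TPU Pod each VM has its own filesystem, so the rank-0 process of every VM (local_rank == 0) must hand results back through its queue, whereas ddp_spawn only needs the single global rank 0. A small illustrative helper, hypothetical and not Lightning API, makes the two guards concrete:

def should_transfer_results(local_rank: int, node_rank: int, procs_per_node: int, tpu_pod: bool) -> bool:
    # global_rank is the process's position across all nodes combined.
    global_rank = node_rank * procs_per_node + local_rank
    if tpu_pod:
        # Separate filesystem per VM: every VM's local rank 0 hands off state.
        return local_rank == 0
    # Shared setup (ddp_spawn): only the single global rank 0 hands off.
    return global_rank == 0


# Rank 0 on the second VM of a pod still transfers; under ddp_spawn it would not:
assert should_transfer_results(local_rank=0, node_rank=1, procs_per_node=8, tpu_pod=True)
assert not should_transfer_results(local_rank=0, node_rank=1, procs_per_node=8, tpu_pod=False)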
