Commit f1eb8b1

Borda committed Feb 1, 2021
1 parent 34859f3

Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/legacy/rpc_plugin.py
@@ -39,7 +39,7 @@ class RPCPlugin(DDPPlugin):
 
     def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs):
         self.rpc_timeout_sec = rpc_timeout_sec
-        self.rpc_initialized = False
+        self._is_rpc_initialized = False
         super().__init__(**kwargs)
 
     def init_rpc_connection(self,
@@ -48,7 +48,7 @@ def init_rpc_connection(self,
         os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
         rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
         rpc._set_rpc_timeout(self.rpc_timeout_sec)
-        self.rpc_initialized = True
+        self._is_rpc_initialized = True
 
     def rpc_save_model(self,
                        save_model_fn,
@@ -86,9 +86,9 @@ def on_accelerator_exit_rpc_process(self, trainer) -> None:
         self.exit_rpc_process()
 
     def exit_rpc_process(self):
-        if self.rpc_initialized:
+        if self._is_rpc_initialized:
             torch.distributed.rpc.shutdown()
-            self.rpc_initialized = False
+            self._is_rpc_initialized = False
 
     @property
     def return_after_exit_rpc_process(self) -> bool:
8 changes: 4 additions & 4 deletions pytorch_lightning/plugins/training_type/rpc.py
@@ -48,7 +48,7 @@ def __init__(
         **kwargs
     ):
         self.rpc_timeout_sec = rpc_timeout_sec
-        self.rpc_initialized = False
+        self._is_rpc_initialized = False
         super().__init__(
             parallel_devices=parallel_devices,
             num_nodes=num_nodes,
@@ -61,7 +61,7 @@ def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
         os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
         rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
         rpc._set_rpc_timeout(self.rpc_timeout_sec)
-        self.rpc_initialized = True
+        self._is_rpc_initialized = True
 
     def rpc_save_model(self, save_model_fn, last_filepath, trainer, pl_module) -> None:
         """
@@ -95,9 +95,9 @@ def on_accelerator_exit_rpc_process(self) -> None:
        self.exit_rpc_process()
 
    def exit_rpc_process(self):
-        if self.rpc_initialized:
+        if self._is_rpc_initialized:
             torch.distributed.rpc.shutdown()
-            self.rpc_initialized = False
+            self._is_rpc_initialized = False
 
     @property
     def return_after_exit_rpc_process(self) -> bool:
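For context, a minimal standalone sketch (not part of the commit) of the guarded RPC lifecycle that the renamed private flag protects. The class name RPCConnectionSketch is hypothetical, it assumes MASTER_ADDR is already set by an external launcher, and it uses only public torch.distributed.rpc calls:

    import os
    import torch.distributed.rpc as rpc

    class RPCConnectionSketch:
        """Hypothetical illustration of the pattern in both changed files."""

        def __init__(self):
            # Private flag, as renamed in this commit; set only after
            # init_rpc succeeds.
            self._is_rpc_initialized = False

        def init_rpc_connection(self, global_rank: int, world_size: int) -> None:
            os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000')
            rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size)
            self._is_rpc_initialized = True

        def exit_rpc_process(self) -> None:
            # Guard so shutdown runs at most once, and only after a
            # successful init; calling rpc.shutdown() without a prior
            # init_rpc raises a RuntimeError.
            if self._is_rpc_initialized:
                rpc.shutdown()
                self._is_rpc_initialized = False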
