diff --git a/pytorch_pfn_extras/distributed/_initialize.py b/pytorch_pfn_extras/distributed/_initialize.py
index 9d09c122..1583db58 100644
--- a/pytorch_pfn_extras/distributed/_initialize.py
+++ b/pytorch_pfn_extras/distributed/_initialize.py
@@ -1,7 +1,8 @@
 import os
+from datetime import timedelta
 from typing import Tuple
 
-import torch
+import torch.distributed
 
 
 def initialize_ompi_environment(
@@ -13,6 +14,7 @@ def initialize_ompi_environment(
     local_rank: int = 0,
     addr: str = "localhost",
     port: str = "1234",
+    timeout: int = 1800,
 ) -> Tuple[int, int, int]:
     """Initialize `torch.distributed` environments with values taken from
     OpenMPI.
@@ -32,6 +34,8 @@ def initialize_ompi_environment(
             Defaults to ``"localhost"``
         port: The port of the master process of `torch.distributed`.
            Defaults to ``"1234"``
+        timeout: Timeout seconds for `torch.distributed` collective communication.
+            Defaults to ``1800``.
     """
     e = os.environ
     backend = backend
@@ -62,7 +66,11 @@ def initialize_ompi_environment(
 
     if world_size > 1 and not torch.distributed.is_initialized():  # type: ignore
         torch.distributed.init_process_group(  # type: ignore
-            backend, init_method=init_method, world_size=world_size, rank=rank
+            backend,
+            init_method=init_method,
+            world_size=world_size,
+            rank=rank,
+            timeout=timedelta(seconds=timeout),
         )
         torch.distributed.barrier()  # type: ignore
 
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
index f1e8c7d9..9113dc48 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_distributed_snapshot.py
@@ -42,7 +42,9 @@ def get_trainer(path):
 def _init_distributed(use_cuda):
     if "OMPI_COMM_WORLD_SIZE" in os.environ:
         size, rank, local_rank = distributed.initialize_ompi_environment(
-            backend="nccl", init_method="env"
+            backend="nccl",
+            init_method="env",
+            timeout=15,
         )
     else:
         pytest.skip("This test requires MPI to run")
@@ -99,6 +101,8 @@ def test_distributed_snapshot(mpi_tmp_path, saver_rank):
         with trainer.run_iteration():
             pass
     assert 1 == trainer.iteration
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     # the snapshot is generated only for the saver rank
@@ -130,10 +134,14 @@ def test_distributed_snapshot_autoload(mpi_tmp_path, saver_rank):
     )
     trainer = get_trainer(mpi_tmp_path)
     trainer.extend(snapshot, trigger=(1, "iteration"), priority=2)
+    if comm_size > 1:
+        torch.distributed.barrier()
     for _ in range(1):
         with trainer.run_iteration():
             pass
     assert 1 == trainer.iteration
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     assert len(found) == 1
@@ -168,6 +176,8 @@ def test_distributed_snapshot_on_error(mpi_tmp_path, saver_rank):
         pass
     dummy_tb = dummy_exception.__traceback__
     snapshot.on_error(trainer, dummy_exception, dummy_tb)
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, f"snapshot_iter_{saver_rank}_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     # the snapshot is generated only for the saver rank
diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_sharded_snapshot.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_sharded_snapshot.py
index 01aa6d9e..394786f3 100644
--- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_sharded_snapshot.py
+++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_sharded_snapshot.py
@@ -16,7 +16,9 @@ def _init_distributed(use_cuda):
     if "OMPI_COMM_WORLD_SIZE" in os.environ:
         size, rank, local_rank = distributed.initialize_ompi_environment(
-            backend="nccl", init_method="env"
+            backend="nccl",
+            init_method="env",
+            timeout=15,
         )
     else:
         pytest.skip("This test requires MPI to run")
 
@@ -305,6 +307,9 @@ def test_sharded_snapshot(mpi_tmp_path):
         with trainer.run_iteration():
             pass
 
+    if comm_size > 1:
+        torch.distributed.barrier()
+
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [path for path in glob(pattern)]
 
@@ -340,6 +345,8 @@ def test_sharded_snapshot_cleanup(mpi_tmp_path):
     for _ in range(5):
         with trainer.run_iteration():
             pass
+        if comm_size > 1:
+            torch.distributed.barrier()
         pattern = os.path.join(trainer.out, "snapshot_iter_*")
         found = [os.path.basename(path) for path in glob(pattern)]
 
@@ -368,12 +375,16 @@ def test_sharded_snapshot_autoload(mpi_tmp_path):
         autoload=True,
     )
-    trainer = get_trainer("./tmp/test_result", device)
+    trainer = get_trainer(mpi_tmp_path, device)
     trainer.extend(snapshot_extension, trigger=(1, "iteration"), priority=2)
 
     for _ in range(5):
         with trainer.run_iteration():
             pass
-    trainer2 = get_trainer("./tmp/test_result", device)
+
+    if comm_size > 1:
+        torch.distributed.barrier()
+
+    trainer2 = get_trainer(mpi_tmp_path, device)
     snapshot_extension2 = snapshot(
         filename=fmt,
         snapshot_mode=SnapshotMode.SHARDED,
@@ -385,7 +396,7 @@ def test_sharded_snapshot_autoload(mpi_tmp_path):
     with pytest.raises(AssertionError):
         _assert_state_dict_is_eq(trainer2.state_dict(), trainer.state_dict())
 
-    trainer3 = get_trainer("./tmp/test_result", device)
+    trainer3 = get_trainer(mpi_tmp_path, device)
     snapshot_extension3 = snapshot(
         filename=fmt,
         snapshot_mode=SnapshotMode.SHARDED,
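
For reference, a minimal usage sketch of the new timeout argument (not part of the patch; the call mirrors how the tests above invoke distributed.initialize_ompi_environment, and values such as backend="gloo" and timeout=60 are illustrative):

    # Run under an OpenMPI launcher (OMPI_COMM_WORLD_* variables set), e.g. mpirun.
    from pytorch_pfn_extras import distributed

    # timeout is in seconds (default 1800) and is forwarded to
    # torch.distributed.init_process_group as timedelta(seconds=timeout).
    size, rank, local_rank = distributed.initialize_ompi_environment(
        backend="gloo",      # the tests above use "nccl"
        init_method="env",
        timeout=60,
    )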