
Commit

fix test
linshokaku committed Dec 21, 2023
1 parent 7977b2d commit 60b48fb
Showing 3 changed files with 36 additions and 7 deletions.
First changed file: pytorch_pfn_extras/distributed/_initialize.py (10 additions & 2 deletions)
@@ -1,7 +1,8 @@
 import os
+from datetime import timedelta
 from typing import Tuple

 import torch
 import torch.distributed


 def initialize_ompi_environment(
@@ -13,6 +14,7 @@ def initialize_ompi_environment(
     local_rank: int = 0,
     addr: str = "localhost",
     port: str = "1234",
+    timeout: int = 1800,
 ) -> Tuple[int, int, int]:
     """Initialize `torch.distributed` environments with values taken from
     OpenMPI.
@@ -32,6 +34,8 @@ def initialize_ompi_environment(
             Defaults to ``"localhost"``
         port: The port of the master process of `torch.distributed`.
            Defaults to ``"1234"``
+        timeout: Timeout seconds for `torch.distributed` collective communication.
+            Defaults to ``1800``.
     """
     e = os.environ
     backend = backend
@@ -62,7 +66,11 @@ def initialize_ompi_environment(

     if world_size > 1 and not torch.distributed.is_initialized():  # type: ignore
         torch.distributed.init_process_group(  # type: ignore
-            backend, init_method=init_method, world_size=world_size, rank=rank
+            backend,
+            init_method=init_method,
+            world_size=world_size,
+            rank=rank,
+            timeout=timedelta(seconds=timeout),
         )
         torch.distributed.barrier()  # type: ignore

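For context, a usage sketch of the extended function follows. It is not part of this commit; the import path assumes initialize_ompi_environment is re-exported from pytorch_pfn_extras.distributed (as the tests below use it), and the backend and timeout values are illustrative only.

# Hypothetical call site showing the new keyword argument.
from pytorch_pfn_extras import distributed

size, rank, local_rank = distributed.initialize_ompi_environment(
    backend="gloo",   # "nccl" on GPU hosts, as in the tests below
    init_method="env",
    timeout=60,       # seconds; converted to timedelta(seconds=60) internally
)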
Second changed file (distributed snapshot tests)
@@ -42,7 +42,9 @@ def get_trainer(path):
 def _init_distributed(use_cuda):
     if "OMPI_COMM_WORLD_SIZE" in os.environ:
         size, rank, local_rank = distributed.initialize_ompi_environment(
-            backend="nccl", init_method="env"
+            backend="nccl",
+            init_method="env",
+            timeout=15,
         )
     else:
         pytest.skip("This test requires MPI to run")
@@ -99,6 +101,8 @@ def test_distributed_snapshot(mpi_tmp_path, saver_rank):
         with trainer.run_iteration():
             pass
     assert 1 == trainer.iteration
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     # the snapshot is generated only for the saver rank
@@ -130,10 +134,14 @@ def test_distributed_snapshot_autoload(mpi_tmp_path, saver_rank):
     )
     trainer = get_trainer(mpi_tmp_path)
     trainer.extend(snapshot, trigger=(1, "iteration"), priority=2)
+    if comm_size > 1:
+        torch.distributed.barrier()
     for _ in range(1):
         with trainer.run_iteration():
             pass
     assert 1 == trainer.iteration
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     assert len(found) == 1
@@ -168,6 +176,8 @@ def test_distributed_snapshot_on_error(mpi_tmp_path, saver_rank):
         pass
     dummy_tb = dummy_exception.__traceback__
     snapshot.on_error(trainer, dummy_exception, dummy_tb)
+    if comm_size > 1:
+        torch.distributed.barrier()
     pattern = os.path.join(trainer.out, f"snapshot_iter_{saver_rank}_*")
     found = [os.path.basename(path) for path in glob.glob(pattern)]
     # the snapshot is generated only for the saver rank
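The barriers added to these tests all follow the same pattern; here it is as a standalone sketch (the helper name and its arguments are illustrative, not code from the repository). Only the saver rank writes the snapshot during the iteration, so every other rank has to wait at a barrier before it inspects the output directory, otherwise the glob can run before the file exists.

import glob
import os

import torch.distributed


def wait_and_list_snapshots(out_dir, comm_size):
    # Synchronize so the saver rank has finished writing before anyone reads.
    if comm_size > 1:
        torch.distributed.barrier()
    pattern = os.path.join(out_dir, "snapshot_iter_*")
    return [os.path.basename(path) for path in glob.glob(pattern)]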
Third changed file (sharded snapshot tests)
@@ -16,7 +16,9 @@
 def _init_distributed(use_cuda):
     if "OMPI_COMM_WORLD_SIZE" in os.environ:
         size, rank, local_rank = distributed.initialize_ompi_environment(
-            backend="nccl", init_method="env"
+            backend="nccl",
+            init_method="env",
+            timeout=15,
         )
     else:
         pytest.skip("This test requires MPI to run")
@@ -305,6 +307,9 @@ def test_sharded_snapshot(mpi_tmp_path):
         with trainer.run_iteration():
             pass

+    if comm_size > 1:
+        torch.distributed.barrier()
+
     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [path for path in glob(pattern)]

@@ -340,6 +345,8 @@ def test_sharded_snapshot_cleanup(mpi_tmp_path):
     for _ in range(5):
         with trainer.run_iteration():
             pass
+    if comm_size > 1:
+        torch.distributed.barrier()

     pattern = os.path.join(trainer.out, "snapshot_iter_*")
     found = [os.path.basename(path) for path in glob(pattern)]
@@ -368,12 +375,16 @@ def test_sharded_snapshot_autoload(mpi_tmp_path):
         autoload=True,
     )

-    trainer = get_trainer("./tmp/test_result", device)
+    trainer = get_trainer(mpi_tmp_path, device)
     trainer.extend(snapshot_extension, trigger=(1, "iteration"), priority=2)
     for _ in range(5):
         with trainer.run_iteration():
             pass
-    trainer2 = get_trainer("./tmp/test_result", device)
+
+    if comm_size > 1:
+        torch.distributed.barrier()
+
+    trainer2 = get_trainer(mpi_tmp_path, device)
     snapshot_extension2 = snapshot(
         filename=fmt,
         snapshot_mode=SnapshotMode.SHARDED,
@@ -385,7 +396,7 @@ def test_sharded_snapshot_autoload(mpi_tmp_path):
     with pytest.raises(AssertionError):
         _assert_state_dict_is_eq(trainer2.state_dict(), trainer.state_dict())

-    trainer3 = get_trainer("./tmp/test_result", device)
+    trainer3 = get_trainer(mpi_tmp_path, device)
     snapshot_extension3 = snapshot(
         filename=fmt,
         snapshot_mode=SnapshotMode.SHARDED,
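The other fix in this file replaces the hardcoded "./tmp/test_result" with the mpi_tmp_path fixture, so every trainer in the autoload test reads and writes snapshots in the same per-test directory on every rank. A minimal sketch of how such a shared directory can be established follows; it is an assumption about what a fixture like mpi_tmp_path does, not code from the repository.

import torch.distributed


def shared_tmp_dir(local_tmp_path, comm_size):
    # Adopt rank 0's temporary directory on all ranks so that the saver's
    # snapshot files are visible at the same path everywhere.
    holder = [str(local_tmp_path)]
    if comm_size > 1:
        torch.distributed.broadcast_object_list(holder, src=0)
    return holder[0]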
