Skip to content

Commit

Permalink
Change order of imports so that initialization only happens after req…
Browse files Browse the repository at this point in the history
…uired environment variables have been set.
  • Loading branch information
Caspar van Leeuwen committed Apr 9, 2024
1 parent a6bf34d commit fce2a45
Showing 1 changed file with 30 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch.utils.data.distributed
from torchvision import models


# Benchmark settings
parser = argparse.ArgumentParser(description='PyTorch Synthetic Benchmark',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
Expand Down Expand Up @@ -84,20 +85,9 @@


if args.use_ddp:
from socket import gethostname
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from socket import gethostname

def setup(rank, world_size):
# initialize the process group
if args.cuda:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
# clean up the distributed environment
dist.destroy_process_group()

# world_size = int(os.environ["SLURM_NTASKS"]) ## No longer needed now we pass it as argument?
# If launched with mpirun, get rank from this
Expand All @@ -110,22 +100,46 @@ def cleanup():
err_msg += " and srun as launchers. If you've configured a different launcher for your system"
err_msg += " this test will need to be extended with a method to get it's local rank for that launcher."
print(err_msg)

setup(rank, world_size)
# log(f"Group initialized? {dist.is_initialized()}", rank)
if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)

# If launched with srun, you are in a CGROUP with only 1 GPU, so you don't need to set it.
# If launched with mpirun, you see ALL local GPUs on the node, and you need to set which one
# this rank should use.
visible_gpus = torch.cuda.device_count()
if visible_gpus > 1:
print("Listing visible devices")
for i in range(torch.cuda.device_count()):
print(f"Device {i}: {torch.cuda.device(i)}")
local_rank = rank - visible_gpus * (rank // visible_gpus)
torch.cuda.set_device(local_rank)
print("Listing visible devices after setting one")
for i in range(torch.cuda.device_count()):
print(f"Device {i}: {torch.cuda.device(i)}")
# We should also set CUDA_VISIBLE_DEVICES, which gets respected by NCCL
os.environ['CUDA_VISIBLE_DEVICES'] = '%s' % local_rank
print(f"host: {gethostname()}, rank: {rank}, local_rank: {local_rank}")
else:
print(f"host: {gethostname()}, rank: {rank}")


def setup(rank, world_size):

# initialize the process group
if args.cuda:
dist.init_process_group("nccl", rank=rank, world_size=world_size)
else:
dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
# clean up the distributed environment
dist.destroy_process_group()

setup(rank, world_size)
# log(f"Group initialized? {dist.is_initialized()}", rank)
if rank == 0: print(f"Group initialized? {dist.is_initialized()}", flush=True)




# This relies on the 'rank' set in the if args.use_horovod or args.use_ddp sections
def log(s, nl=True):
if (args.use_horovod or args.use_ddp) and rank != 0:
Expand Down

0 comments on commit fce2a45

Please sign in to comment.