import logging
import os
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, Subset
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.cuda import device_count
logger = logging.getLogger(__name__)
rank = 0
world_size = device_count()


def init(args):
    """init.
    Initialize DDP using the given rendezvous file.
    """
    global rank, world_size
    if args.ddp:
        assert args.rank is not None and args.world_size is not None
        rank = args.rank
        world_size = args.world_size
    if world_size == 1:
        return
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=args.ddp_backend,
        init_method='file://' + os.path.abspath(args.rendezvous_file),
        world_size=world_size,
        rank=rank)
    logger.debug("Distributed rendezvous went well, rank %d/%d", rank, world_size)
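

# Sketch of how each spawned worker might call `init` (the Namespace fields are
# simply the attributes this module reads; the rank/world-size values and the
# 'rdv.txt' path are placeholders, not part of this module):
#   init(argparse.Namespace(ddp=True, rank=worker_rank, world_size=num_gpus,
#                           ddp_backend='nccl', rendezvous_file='rdv.txt'))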


def average(metrics, count=1.):
    """average.
    Average all the relevant metrics across processes.
    `metrics` should be a 1D float32 vector. Returns the average of `metrics`
    over all hosts. You can use `count` to control the weight of each worker.
    """
    if world_size == 1:
        return metrics
    # Append the weight `count` as an extra entry so that dividing by the
    # all-reduced last element yields a weighted average across workers.
    tensor = torch.tensor(list(metrics) + [1], device='cuda', dtype=torch.float32)
    tensor *= count
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return (tensor[:-1] / tensor[-1]).cpu().numpy().tolist()
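

# Hypothetical usage: average per-worker running losses, weighting each worker
# by the number of batches it processed (the names below are placeholders, and
# the call is a no-op when world_size == 1):
#   train_loss, valid_loss = average([train_loss, valid_loss], count=num_batches)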


def wrap(model):
    """wrap.
    Wrap a model with DDP if distributed training is enabled.
    """
    if world_size == 1:
        return model
    else:
        return DistributedDataParallel(
            model,
            device_ids=[torch.cuda.current_device()],
            output_device=torch.cuda.current_device())


def barrier():
    """barrier.
    Synchronize all processes; a no-op when not distributed.
    """
    if world_size > 1:
        torch.distributed.barrier()


def loader(dataset, *args, shuffle=False, klass=DataLoader, **kwargs):
    """loader.
    Create a dataloader properly in case of distributed training.
    If a gradient is going to be computed you must set `shuffle=True`.

    :param dataset: the dataset to be parallelized
    :param args: relevant positional args for the loader
    :param shuffle: shuffle examples
    :param klass: loader class
    :param kwargs: relevant keyword args for the loader
    """
    if world_size == 1:
        return klass(dataset, *args, shuffle=shuffle, **kwargs)

    if shuffle:
        # Training (i.e. a backward pass) requires a DistributedSampler so that
        # each worker sees a distinct shard of the data.
        sampler = DistributedSampler(dataset)
        # We ignore `shuffle` here, as DistributedSampler already shuffles.
        return klass(dataset, *args, **kwargs, sampler=sampler)
    else:
        # We make a manual shard, as DistributedSampler otherwise replicates some examples.
        dataset = Subset(dataset, list(range(rank, len(dataset), world_size)))
        return klass(dataset, *args, shuffle=shuffle, **kwargs)
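

# ---------------------------------------------------------------------------
# Minimal single-process usage sketch. It assumes exactly one visible GPU, so
# `world_size == 1` and no process group is needed; the TensorDataset and the
# Linear model are just placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
    train_loader = loader(dataset, batch_size=4, shuffle=True)
    model = wrap(torch.nn.Linear(4, 1))
    for x, y in train_loader:
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
    barrier()  # no-op here, but keeps the multi-process call order explicit
    print(average([loss.item()]))  # also returns the input unchanged when world_size == 1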