forked from tomgoldstein/loss-landscape
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mpi4pytorch.py
executable file
·120 lines (97 loc) · 3.38 KB
/
mpi4pytorch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
""" mpi4pytorch.py
This module contains convenience methods that make it easy to use mpi4py. The available functions handle memory
allocation and other data formatting tasks so that tensors can be easily reduced/broadcast using 1 line of code.
"""
import numpy as np
import mpi4py
def setup_MPI():
    """Return a wrapped ``MPI.COMM_WORLD`` communicator, or None without MPI.

    The world communicator is re-created as an instance of a trivial
    Python-level subclass of ``MPI.Intracomm`` so that callers can attach
    extra attributes to it later.

    Returns:
        The wrapped communicator, or ``None`` if mpi4py cannot be imported or
        the communicator cannot be created (serial fallback).
    """
    try:
        from mpi4py import MPI

        # Convert the object to a Python class instance so that it is
        # possible to add attributes to it later.
        class A(MPI.Intracomm):
            pass

        comm = A(MPI.COMM_WORLD)
    except Exception:
        # Deliberate best effort: any setup failure means "run without MPI".
        # Unlike a bare except, this still lets KeyboardInterrupt/SystemExit
        # propagate.
        comm = None
    return comm
def print_once(comm, *message):
    """Print the concatenated ``str()`` of every argument, from one process only.

    Output happens when ``comm`` is falsy (serial run) or when the calling
    process has rank 0; every other rank stays silent.
    """
    if comm and comm.Get_rank() != 0:
        return
    text = ''.join(map(str, message))
    print(text)
def is_master(comm):
    """Return True for a serial run (falsy ``comm``) or for rank 0, else False."""
    if not comm:
        return True
    return comm.Get_rank() == 0
def allreduce_max(comm, array, display_info=False):
    """Element-wise maximum of ``array`` over all ranks, returned on every rank.

    Args:
        comm: MPI communicator (e.g. from ``setup_MPI``), or a falsy value such
            as ``None`` to run serially.
        array: array-like of numbers; converted to a float64 ndarray before
            the reduction.
        display_info: if True, each rank prints its local sum and buffer size
            in bytes, and rank 0 prints the gathered row/column counts. This
            uses collective ``gather`` calls, so all ranks must pass the same
            value.

    Returns:
        ``array`` itself (unconverted) when ``comm`` is falsy; otherwise a new
        float64 ndarray holding the element-wise maximum across ranks.
    """
    if not comm:
        return array
    array = np.asarray(array, dtype='d')
    # Receive buffer pre-filled with the smallest float64 (identity of MAX).
    # np.float was removed in NumPy 1.24, so name float64 explicitly.
    total = np.full_like(array, np.finfo(np.float64).min)
    if display_info:
        print ("(%d): sum=%f : size=%d"%(get_rank(comm), np.sum(array), array.nbytes))
        # Pad the shape so 0-D/1-D arrays report 1 for missing dimensions
        # instead of raising IndexError; 2-D output is unchanged.
        dims = array.shape + (1, 1)
        rows = str(comm.gather(dims[0]))
        cols = str(comm.gather(dims[1]))
        print_once(comm, "reduce: %s, %s"%(rows, cols))
    comm.Allreduce(array, total, op=mpi4py.MPI.MAX)
    return total
def allreduce_min(comm, array, display_info=False):
    """Element-wise minimum of ``array`` over all ranks, returned on every rank.

    Args:
        comm: MPI communicator (e.g. from ``setup_MPI``), or a falsy value such
            as ``None`` to run serially.
        array: array-like of numbers; converted to a float64 ndarray before
            the reduction.
        display_info: if True, each rank prints its local sum and buffer size
            in bytes, and rank 0 prints the gathered row/column counts. This
            uses collective ``gather`` calls, so all ranks must pass the same
            value.

    Returns:
        ``array`` itself (unconverted) when ``comm`` is falsy; otherwise a new
        float64 ndarray holding the element-wise minimum across ranks.
    """
    if not comm:
        return array
    array = np.asarray(array, dtype='d')
    # Receive buffer pre-filled with the largest float64 (identity of MIN).
    # np.float was removed in NumPy 1.24, so name float64 explicitly.
    total = np.full_like(array, np.finfo(np.float64).max)
    if display_info:
        print ("(%d): sum=%f : size=%d"%(get_rank(comm), np.sum(array), array.nbytes))
        # Pad the shape so 0-D/1-D arrays report 1 for missing dimensions
        # instead of raising IndexError; 2-D output is unchanged.
        dims = array.shape + (1, 1)
        rows = str(comm.gather(dims[0]))
        cols = str(comm.gather(dims[1]))
        print_once(comm, "reduce: %s, %s"%(rows, cols))
    comm.Allreduce(array, total, op=mpi4py.MPI.MIN)
    return total
def reduce_max(comm, array, display_info=False):
    """Element-wise maximum of ``array`` over all ranks, delivered to rank 0.

    Args:
        comm: MPI communicator (e.g. from ``setup_MPI``), or a falsy value such
            as ``None`` to run serially.
        array: array-like of numbers; converted to a float64 ndarray before
            the reduction.
        display_info: if True, each rank prints its local sum and buffer size
            in bytes, and rank 0 prints the gathered row/column counts. This
            uses collective ``gather`` calls, so all ranks must pass the same
            value.

    Returns:
        ``array`` itself (unconverted) when ``comm`` is falsy; otherwise a new
        float64 ndarray. Because the reduction uses ``root=0``, only rank 0's
        result holds the element-wise maximum; on other ranks the buffer
        should not be relied upon.
    """
    if not comm:
        return array
    array = np.asarray(array, dtype='d')
    # Receive buffer pre-filled with the smallest float64 (identity of MAX).
    # np.float was removed in NumPy 1.24, so name float64 explicitly.
    total = np.full_like(array, np.finfo(np.float64).min)
    if display_info:
        print ("(%d): sum=%f : size=%d"%(get_rank(comm), np.sum(array), array.nbytes))
        # Pad the shape so 0-D/1-D arrays report 1 for missing dimensions
        # instead of raising IndexError; 2-D output is unchanged.
        dims = array.shape + (1, 1)
        rows = str(comm.gather(dims[0]))
        cols = str(comm.gather(dims[1]))
        print_once(comm, "reduce: %s, %s"%(rows, cols))
    comm.Reduce(array, total, op=mpi4py.MPI.MAX, root=0)
    return total
def reduce_min(comm, array, display_info=False):
    """Element-wise minimum of ``array`` over all ranks, delivered to rank 0.

    Args:
        comm: MPI communicator (e.g. from ``setup_MPI``), or a falsy value such
            as ``None`` to run serially.
        array: array-like of numbers; converted to a float64 ndarray before
            the reduction.
        display_info: if True, each rank prints its local sum and buffer size
            in bytes, and rank 0 prints the gathered row/column counts. This
            uses collective ``gather`` calls, so all ranks must pass the same
            value.

    Returns:
        ``array`` itself (unconverted) when ``comm`` is falsy; otherwise a new
        float64 ndarray. Because the reduction uses ``root=0``, only rank 0's
        result holds the element-wise minimum; on other ranks the buffer
        should not be relied upon.
    """
    if not comm:
        return array
    array = np.asarray(array, dtype='d')
    # Receive buffer pre-filled with the largest float64 (identity of MIN).
    # np.float was removed in NumPy 1.24, so name float64 explicitly.
    total = np.full_like(array, np.finfo(np.float64).max)
    if display_info:
        print ("(%d): sum=%f : size=%d"%(get_rank(comm), np.sum(array), array.nbytes))
        # Pad the shape so 0-D/1-D arrays report 1 for missing dimensions
        # instead of raising IndexError; 2-D output is unchanged.
        dims = array.shape + (1, 1)
        rows = str(comm.gather(dims[0]))
        cols = str(comm.gather(dims[1]))
        print_once(comm, "reduce: %s, %s"%(rows, cols))
    comm.Reduce(array, total, op=mpi4py.MPI.MIN, root=0)
    return total
def barrier(comm):
    """Synchronize all ranks via ``comm.barrier()``; do nothing if ``comm`` is falsy."""
    if comm:
        comm.barrier()
def get_mpi_info():
    """Describe the MPI implementation in use.

    Returns:
        ``MPI.get_vendor()`` when the mpi4py ``MPI`` submodule can be
        imported, otherwise the string ``"none"``.
    """
    # `import mpi4py` alone does not load the MPI submodule, so
    # `mpi4py.MPI` would raise AttributeError (not ImportError). Import it
    # explicitly so the ImportError fallback actually applies.
    try:
        from mpi4py import MPI
    except ImportError:
        return "none"
    return MPI.get_vendor()
def get_rank(comm):
    """Return this process's rank in ``comm``, or 0 when running without MPI.

    Args:
        comm: MPI communicator, or a falsy value such as ``None``.
    """
    # The previous `except ImportError` could never fire: calling Get_rank on
    # None raises AttributeError. Guard on a falsy comm instead, matching the
    # other helpers in this module.
    if not comm:
        return 0
    return comm.Get_rank()
def get_num_procs(comm):
    """Return the number of processes in ``comm``, or 1 when running without MPI.

    Args:
        comm: MPI communicator, or a falsy value such as ``None``.
    """
    # The previous `except ImportError` could never fire: calling Get_size on
    # None raises AttributeError. Guard on a falsy comm instead, matching the
    # other helpers in this module.
    if not comm:
        return 1
    return comm.Get_size()