-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmpi_test.py
46 lines (33 loc) · 1.43 KB
/
mpi_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import argparse
import os
import sys
def _build_parser():
    """Return the argument parser for this script.

    Options are registered from a table so that their order (and therefore
    the order shown by ``--help``) is easy to see at a glance.
    """
    cli = argparse.ArgumentParser()
    # (flag, keyword arguments forwarded to ``add_argument``), in display order.
    option_table = [
        ('--name', {'type': str, 'default': 'mpi_test', 'help': "job name"}),
        ('--instance_type', {'type': str, 'default': "c5.large"}),
        ('--num_tasks', {'type': int, 'default': 2}),
        ('--image_name', {'type': str,
                          'default': 'Deep Learning AMI (Ubuntu) Version 23.0'}),
        ('--spot', {'action': 'store_true',
                    'help': 'use spot instead of regular instances'}),
        ('--nproc_per_node', {'type': int, 'default': 1}),
        ('--conda_env', {'type': str, 'default': 'pytorch_p36'}),
        ('--skip_setup', {'action': 'store_true'}),
        ('--role', {'type': str, 'default': 'launcher',
                    'help': 'internal flag, launcher or worker'}),
    ]
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


# The command line is parsed at import time; the rest of the file reads the
# module-level ``args`` namespace.
parser = _build_parser()
args = parser.parse_args()
def launcher():
    """Create the job and start this script on it under ``mpirun``.

    Steps, in order:
      1. build a job with ``ncluster.make_job`` from every parsed CLI option;
      2. call ``job.rsync('.')``;
      3. obtain the hosts-file text from ``util.setup_mpi`` and write it to
         ``hosts.slots`` on the first task;
      4. run ``mpirun`` on the first task, re-invoking this script with
         ``--role=worker`` and streaming its output.

    The total process count passed to ``mpirun -n`` follows ``--num_tasks``
    (one process per node via ``-N 1``) instead of being fixed at 2, so the
    default invocation produces the same command as before.
    """
    # Imported inside the function; presumably so that the worker role does
    # not need these modules installed -- TODO confirm.
    import ncluster
    import util

    job = ncluster.make_job(**vars(args))
    job.rsync('.')
    first_task = job.tasks[0]

    # Only the hosts-file text is needed here; the first return value is unused.
    _hosts_str, hosts_file_str = util.setup_mpi(job, max_slots=1)
    first_task.write('hosts.slots', hosts_file_str)

    script_fn = os.path.basename(__file__)
    # With one slot per host and -N 1, the total equals the number of tasks.
    first_task.run(
        f'mpirun -n {args.num_tasks} -N 1 --hostfile hosts.slots '
        f'python {script_fn} --role=worker',
        stream_output=True)
def main():
    """Dispatch on ``args.role``.

    * ``"launcher"``: call :func:`launcher`.
    * ``"worker"``: print one line with the host name, the ``RANK``
      environment variable (``-1`` when it is not set) and the command line.

    Raises:
        ValueError: if ``args.role`` is neither ``"launcher"`` nor
            ``"worker"``. Such a value used to be ignored silently, which
            made a mistyped ``--role`` look like a successful run.
    """
    role = args.role
    if role == "launcher":
        launcher()
        return
    if role == "worker":
        # NOTE(review): RANK is read straight from the environment; whether
        # mpirun exports a variable with this name is not visible here -- confirm.
        print(f"{os.uname()[1]} {os.environ.get('RANK', -1)} {' '.join(sys.argv)}")
        return
    raise ValueError(
        f"unknown --role {role!r}; expected 'launcher' or 'worker'")
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()