# spawn.py
from __future__ import absolute_import, division, print_function, unicode_literals

import multiprocessing
import multiprocessing.connection
import signal
import sys
import warnings

from . import _prctl_pr_set_pdeathsig

def _wrap(fn, i, args, error_queue):
    # prctl(2) is a Linux specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except Exception:
        # Propagate exception to parent process, keeping original traceback
        import traceback
        error_queue.put(traceback.format_exc())
        sys.exit(1)

# Multiprocessing contexts were introduced in Python 3.4.
_supports_context = sys.version_info >= (3, 4)


def _python_version_check():
    if not _supports_context:
        raise RuntimeError("Requires python 3.4 or higher to use "
                           "torch.multiprocessing.spawn and "
                           "torch.multiprocessing.ProcessContext helper "
                           "to launch multiple processes. If you are using "
                           "this for distributed training and have a lower "
                           "version of python, please use "
                           "torch.distributed.launch instead.")

class ProcessContext:
    def __init__(self, processes, error_queues):
        _python_version_check()
        self.error_queues = error_queues
        self.processes = processes
        self.sentinels = {
            process.sentinel: index
            for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""
        Tries to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Arguments:
            timeout (float): Wait this long before giving up on waiting.
        """
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive.
        for process in self.processes:
            if process.is_alive():
                process.terminate()
            process.join()

        # There won't be an error on the queue if the process crashed.
        if self.error_queues[error_index].empty():
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                name = signal.Signals(-exitcode).name
                raise Exception(
                    "process %d terminated with signal %s" %
                    (error_index, name)
                )
            else:
                raise Exception(
                    "process %d terminated with exit code %d" %
                    (error_index, exitcode)
                )

        original_trace = self.error_queues[error_index].get()
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise Exception(msg)
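

# Editorial usage sketch (an assumption, not part of the original module):
# polling ``join`` with a timeout lets the parent do other work between
# checks. ``_example_worker`` and ``_example_poll_join`` are hypothetical
# names introduced for illustration only.
def _example_worker(i):
    print('worker %d finished' % i)


def _example_poll_join():
    # start_processes (defined below) builds the ProcessContext for us.
    context = start_processes(_example_worker, nprocs=2, join=False)
    while not context.join(timeout=1.0):
        pass  # e.g. poll for shutdown requests or report progress here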


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_queues):
        warnings.warn('SpawnContext is renamed to ProcessContext since 1.4 release.')
        super(SpawnContext, self).__init__(processes, error_queues)

# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It's meant
# to be a more general API than mp.spawn. Currently we only document mp.spawn,
# as it's the CUDA-compatible start method. However, in environments like
# IPython notebooks, 'fork' works better than 'spawn'. The helper functions
# created for mp.spawn are general enough that backends like XLA can reuse
# them in Colab notebooks as well. For now we only add this API; we can
# consider documenting it as needed in the future.
def start_processes(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    _python_version_check()
    mp = multiprocessing.get_context(start_method)
    error_queues = []
    processes = []
    for i in range(nprocs):
        error_queue = mp.SimpleQueue()
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, error_queue),
            daemon=daemon,
        )
        process.start()
        error_queues.append(error_queue)
        processes.append(process)

    context = ProcessContext(processes, error_queues)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
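

# Editorial sketch (an assumption, not original code): if a worker raises, the
# blocking join inside start_processes re-raises in the parent and embeds the
# child's traceback. ``_example_failing_worker`` is a hypothetical name.
def _example_failing_worker(i):
    if i == 0:
        raise ValueError('failure in process %d' % i)


def _example_error_propagation():
    try:
        start_processes(_example_failing_worker, nprocs=2, start_method='spawn')
    except Exception as e:
        # The message includes the child's formatted traceback, prefixed by
        # "-- Process 0 terminated with the following error:"
        print(e)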


def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method='spawn'):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. In the case an exception was caught in the
    child process, it is forwarded and its traceback is included in
    the exception raised in the parent process.

    Arguments:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed-through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (string): The multiprocessing start method to be used
            to create new processes. If CUDA is available and used, it must
            be set to ``spawn``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``
    """
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
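

# Editorial usage sketch (hypothetical names): the typical pattern for
# launching one worker per process index. The entrypoint must be a top-level
# function so the 'spawn' start method can pickle it.
def _example_main(i, world_size):
    print('running process %d of %d' % (i, world_size))


if __name__ == '__main__':
    spawn(_example_main, args=(4,), nprocs=4)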