-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] Accelerate training (#1168)
* accelerate training * use multi_processes * get args from config * reorganize env setup * import from mmdet * import from mmdet
- Loading branch information
Showing
6 changed files
with
142 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import cv2 | ||
import os | ||
import platform | ||
import warnings | ||
from torch import multiprocessing as mp | ||
|
||
|
||
def setup_multi_processes(cfg): | ||
"""Setup multi-processing environment variables.""" | ||
# set multi-process start method as `fork` to speed up the training | ||
if platform.system() != 'Windows': | ||
mp_start_method = cfg.get('mp_start_method', 'fork') | ||
current_method = mp.get_start_method(allow_none=True) | ||
if current_method is not None and current_method != mp_start_method: | ||
warnings.warn( | ||
f'Multi-processing start method `{mp_start_method}` is ' | ||
f'different from the previous setting `{current_method}`.' | ||
f'It will be force set to `{mp_start_method}`. You can change ' | ||
f'this behavior by changing `mp_start_method` in your config.') | ||
mp.set_start_method(mp_start_method, force=True) | ||
|
||
# disable opencv multithreading to avoid system being overloaded | ||
opencv_num_threads = cfg.get('opencv_num_threads', 0) | ||
cv2.setNumThreads(opencv_num_threads) | ||
|
||
# setup OMP threads | ||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa | ||
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: | ||
omp_num_threads = 1 | ||
warnings.warn( | ||
f'Setting OMP_NUM_THREADS environment variable for each process ' | ||
f'to be {omp_num_threads} in default, to avoid your system being ' | ||
f'overloaded, please further tune the variable for optimal ' | ||
f'performance in your application as needed.') | ||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) | ||
|
||
# setup MKL threads | ||
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: | ||
mkl_num_threads = 1 | ||
warnings.warn( | ||
f'Setting MKL_NUM_THREADS environment variable for each process ' | ||
f'to be {mkl_num_threads} in default, to avoid your system being ' | ||
f'overloaded, please further tune the variable for optimal ' | ||
f'performance in your application as needed.') | ||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import cv2 | ||
import multiprocessing as mp | ||
import os | ||
import platform | ||
from mmcv import Config | ||
|
||
from mmdet3d.utils import setup_multi_processes | ||
|
||
|
||
def test_setup_multi_processes(): | ||
# temp save system setting | ||
sys_start_mehod = mp.get_start_method(allow_none=True) | ||
sys_cv_threads = cv2.getNumThreads() | ||
# pop and temp save system env vars | ||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) | ||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) | ||
|
||
# test config without setting env | ||
config = dict(data=dict(workers_per_gpu=2)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert os.getenv('OMP_NUM_THREADS') == '1' | ||
assert os.getenv('MKL_NUM_THREADS') == '1' | ||
# when set to 0, the num threads will be 1 | ||
assert cv2.getNumThreads() == 1 | ||
if platform.system() != 'Windows': | ||
assert mp.get_start_method() == 'fork' | ||
|
||
# test num workers <= 1 | ||
os.environ.pop('OMP_NUM_THREADS') | ||
os.environ.pop('MKL_NUM_THREADS') | ||
config = dict(data=dict(workers_per_gpu=0)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert 'OMP_NUM_THREADS' not in os.environ | ||
assert 'MKL_NUM_THREADS' not in os.environ | ||
|
||
# test manually set env var | ||
os.environ['OMP_NUM_THREADS'] = '4' | ||
config = dict(data=dict(workers_per_gpu=2)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert os.getenv('OMP_NUM_THREADS') == '4' | ||
|
||
# test manually set opencv threads and mp start method | ||
config = dict( | ||
data=dict(workers_per_gpu=2), | ||
opencv_num_threads=4, | ||
mp_start_method='spawn') | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert cv2.getNumThreads() == 4 | ||
assert mp.get_start_method() == 'spawn' | ||
|
||
# revert setting to avoid affecting other programs | ||
if sys_start_mehod: | ||
mp.set_start_method(sys_start_mehod, force=True) | ||
cv2.setNumThreads(sys_cv_threads) | ||
if sys_omp_threads: | ||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads | ||
else: | ||
os.environ.pop('OMP_NUM_THREADS') | ||
if sys_mkl_threads: | ||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads | ||
else: | ||
os.environ.pop('MKL_NUM_THREADS') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters