Merge pull request #781 from alejoe91/mp_context
Add mp_context
samuelgarcia authored Jul 4, 2022
2 parents b0cca91 + 46e60bb commit 486bc67
Showing 3 changed files with 33 additions and 9 deletions.
3 changes: 3 additions & 0 deletions spikeinterface/core/baserecording.py
@@ -22,6 +22,9 @@ class BaseRecording(BaseExtractor):
_main_properties = ['group', 'location', 'gain_to_uV', 'offset_to_uV']
_main_features = [] # recording do not handle features

# multiprocessing context preference. If None, then it's OS-dependent
preferred_mp_context = None

def __init__(self, sampling_frequency: float, channel_ids: List, dtype):
BaseExtractor.__init__(self, channel_ids)

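The preferred_mp_context attribute added above lets a recording class declare which multiprocessing start method it needs; executors fall back to it when no context is given explicitly. A hedged sketch, not part of this diff, of how a subclass might opt into "spawn" (SpawnOnlyRecording is a made-up name; NumpyRecording is used only as a convenient concrete base):

# Hedged sketch, not part of this diff: a recording class that prefers "spawn".
# SpawnOnlyRecording is a made-up name; NumpyRecording is an existing in-memory recording.
import numpy as np
from spikeinterface.core import NumpyRecording

class SpawnOnlyRecording(NumpyRecording):
    # executors fall back to this whenever mp_context is not given explicitly
    preferred_mp_context = "spawn"

rec = SpawnOnlyRecording(np.zeros((30000, 4), dtype="float32"), sampling_frequency=30000.)
print(rec.preferred_mp_context)  # -> "spawn"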
29 changes: 22 additions & 7 deletions spikeinterface/core/job_tools.py
@@ -3,15 +3,17 @@
"""
from pathlib import Path
import numpy as np
import platform

import joblib
import sys
from tqdm.auto import tqdm
import contextlib
from tqdm.auto import tqdm


# import loky
from concurrent.futures import ProcessPoolExecutor
import multiprocessing as mp

_shared_job_kwargs_doc = \
"""**job_kwargs: keyword arguments for parallel processing:
@@ -28,9 +30,13 @@
Number of jobs to use. With -1 the number of jobs is the same as number of cores
* progress_bar: bool
If True, a progress bar is printed
* mp_context: str or None
Context for multiprocessing. It can be None (default), "fork" or "spawn".
Note that "fork" is only available on UNIX systems
"""

job_keys = ['n_jobs', 'total_memory', 'chunk_size', 'chunk_memory', 'chunk_duration', 'progress_bar', 'verbose']
job_keys = ['n_jobs', 'total_memory', 'chunk_size', 'chunk_memory', 'chunk_duration', 'progress_bar',
'mp_context', 'verbose']


# from https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution
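Because mp_context is now part of job_keys, it travels with the other job kwargs through any function that accepts them. A hedged usage sketch (generate_recording and recording.save are existing SpikeInterface APIs, but the exact call below is illustrative and not taken from this diff):

# Hedged usage sketch: forwarding mp_context together with the other job kwargs.
from spikeinterface.core import generate_recording

recording = generate_recording(durations=[2.0])
job_kwargs = dict(n_jobs=2, chunk_duration="500ms", progress_bar=True, mp_context="spawn")
# recording.save() is one consumer of these keyword arguments
recording.save(folder="rec_binary", **job_kwargs)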
@@ -72,7 +78,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
return chunks


def devide_recording_into_chunks(recording, chunk_size):
def divide_recording_into_chunks(recording, chunk_size):
all_chunks = []
for segment_index in range(recording.get_num_segments()):
num_frames = recording.get_num_samples(segment_index)
@@ -216,6 +222,9 @@ class ChunkRecordingExecutor:
Size of each chunk in number of samples. If 'total_memory' or 'chunk_memory' are used, it is ignored.
chunk_duration : str or float or None
Chunk duration in s if float or with units if str (e.g. '1s', '500ms')
mp_context : str or None
"fork" (default) or "spawn". If None, the context is taken by the recording.preferred_mp_context.
"fork" is only available on UNIX systems.
job_name: str
Job name
@@ -227,12 +236,18 @@ class ChunkRecordingExecutor:

def __init__(self, recording, func, init_func, init_args, verbose=False, progress_bar=False, handle_returns=False,
n_jobs=1, total_memory=None, chunk_size=None, chunk_memory=None, chunk_duration=None,
job_name=''):

mp_context=None, job_name=''):
self.recording = recording
self.func = func
self.init_func = init_func
self.init_args = init_args

if mp_context is None:
mp_context = recording.preferred_mp_context
if mp_context is not None and platform.system() == "Windows":
assert mp_context != "fork", "'fork' mp_context not supported on Windows!"

self.mp_context = mp_context

self.verbose = verbose
self.progress_bar = progress_bar
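The fallback rule introduced in __init__ above is short: an explicit mp_context wins, otherwise the recording's preferred_mp_context is used, and "fork" is rejected on Windows. A standalone restatement of that logic, hedged (resolve_mp_context is a made-up name, not part of the diff):

# Standalone restatement of the fallback logic above; resolve_mp_context is a made-up name.
import platform

def resolve_mp_context(mp_context, recording):
    if mp_context is None:
        mp_context = recording.preferred_mp_context  # may still be None -> OS default
    if mp_context is not None and platform.system() == "Windows":
        assert mp_context != "fork", "'fork' mp_context not supported on Windows!"
    return mp_context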
@@ -253,7 +268,7 @@ def run(self):
"""
Runs the defined jobs.
"""
all_chunks = devide_recording_into_chunks(self.recording, self.chunk_size)
all_chunks = divide_recording_into_chunks(self.recording, self.chunk_size)

if self.handle_returns:
returns = []
@@ -275,7 +290,6 @@
returns.append(res)
else:
n_jobs = min(self.n_jobs, len(all_chunks))

######## Do you want to limit the number of threads per process?
######## It has to be done to speed up numpy a lot if multicores
######## Otherwise, np.dot will be slow. How to do that, up to you
@@ -284,6 +298,7 @@
# parallel
with ProcessPoolExecutor(max_workers=n_jobs,
initializer=worker_initializer,
mp_context=mp.get_context(self.mp_context),
initargs=(self.func, self.init_func, self.init_args)) as executor:

results = executor.map(function_wrapper, all_chunks)
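For reference, mp.get_context() maps a start-method name (or None for the OS default) to a multiprocessing context object, which ProcessPoolExecutor accepts via its mp_context parameter, as used above. A minimal, self-contained sketch of the same standard-library pattern (the square worker is made up):

# Minimal standalone sketch of the pattern used above; square() is a made-up worker.
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

def square(x):
    return x * x

if __name__ == "__main__":
    ctx = mp.get_context("spawn")  # or mp.get_context(None) for the OS default
    with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as executor:
        print(list(executor.map(square, range(4))))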
10 changes: 8 additions & 2 deletions spikeinterface/core/tests/test_job_tools.py
@@ -102,14 +102,20 @@ def test_ChunkRecordingExecutor():
n_jobs=1, chunk_memory="500k")
processor.run()

# chunk + parralel
# chunk + parallel
processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
verbose=True, progress_bar=True,
#~ n_jobs=2, total_memory="200k",
n_jobs=2, chunk_duration="200ms",
job_name='job_name')
processor.run()

# chunk + parallel + spawn
processor = ChunkRecordingExecutor(recording, func, init_func, init_args,
verbose=True, progress_bar=True,
mp_context="spawn",
n_jobs=2, chunk_duration="200ms",
job_name='job_name')
processor.run()

if __name__ == '__main__':
test_divide_segment_into_chunks()
