diff --git a/pandora/bootstrap.py b/pandora/bootstrap.py index d94458b..b01c6b6 100644 --- a/pandora/bootstrap.py +++ b/pandora/bootstrap.py @@ -94,9 +94,10 @@ class ProcessWrapper: TODO: Docstring """ - def __init__(self, func: Callable, args: Iterable[Any]): + def __init__(self, func: Callable, args: Iterable[Any], context: multiprocessing.context.BaseContext): self.func = func self.args = args + self.context = context self.process = None @@ -106,7 +107,7 @@ def __init__(self, func: Callable, args: Iterable[Any]): self.is_paused = False # prevent race conditions when handling a signal - self.lock = multiprocessing.RLock() + self.lock = self.context.RLock() def run(self): with self.lock: @@ -117,7 +118,7 @@ def run(self): with tempfile.NamedTemporaryFile() as result_tmpfile: result_tmpfile = pathlib.Path(result_tmpfile.name) with self.lock: - self.process = multiprocessing.Process( + self.process = self.context.Process( target=functools.partial( _wrapped_func, self.func, self.args, result_tmpfile ), @@ -216,7 +217,9 @@ class ParallelBoostrapProcessManager: """ def __init__(self, func: Callable, args: Iterable[Any]): - self.processes = [ProcessWrapper(func, arg) for arg in args] + self.context = multiprocessing.get_context("spawn") + self.processes = [ProcessWrapper(func, arg, self.context) for arg in args] + def run( self, diff --git a/pandora/embedding_comparison.py b/pandora/embedding_comparison.py index 0c09dd0..5610602 100644 --- a/pandora/embedding_comparison.py +++ b/pandora/embedding_comparison.py @@ -4,6 +4,7 @@ import concurrent.futures import itertools +import multiprocessing import warnings from typing import List, Optional, Tuple @@ -367,7 +368,7 @@ def get_pairwise_stabilities(self, threads: Optional[int] = None) -> pd.Series: Each value is between 0 and 1 with higher values indicating a higher stability. """ - with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool: + with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool: pairwise_stabilities = pool.map( _stability_for_pair, itertools.combinations(enumerate(self.embeddings), r=2), @@ -428,7 +429,7 @@ def get_pairwise_cluster_stabilities( enumerate(self.embeddings), r=2 ) ] - with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool: + with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool: pairwise_cluster_stabilities = pool.map(_cluster_stability_for_pair, args) pairwise_cluster_stabilities = pd.concat(pairwise_cluster_stabilities) @@ -463,7 +464,7 @@ def _get_pairwise_difference( (embedding1, embedding2, sample_ids) for embedding1, embedding2 in itertools.permutations(self.embeddings, r=2) ] - with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool: + with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool: diffs = pool.map(_difference_for_pair, args) return diffs @@ -496,7 +497,7 @@ def get_sample_support_values(self, threads: Optional[int] = None) -> pd.Series: ) args = [(embedding, sample_ids_superset) for embedding in self.embeddings] - with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool: + with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool: embedding_norms = pool.map(_get_embedding_norm, args) denominator = 2 * len(self.embeddings) * np.sum(embedding_norms, axis=0) + 1e-6 diff --git a/pandora/main.py b/pandora/main.py index 9764282..b873cb7 100644 --- a/pandora/main.py +++ b/pandora/main.py @@ -1,7 +1,6 @@ import argparse import datetime import math -import multiprocessing import pathlib import sys import textwrap @@ -134,5 +133,4 @@ def main(): if __name__ == "__main__": - multiprocessing.set_start_method("spawn") main()