Skip to content

Commit

Permalink
wip: context
Browse files Browse the repository at this point in the history
  • Loading branch information
tschuelia committed Oct 25, 2023
1 parent 6150e35 commit 176c489
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
11 changes: 7 additions & 4 deletions pandora/bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ class ProcessWrapper:
TODO: Docstring
"""

def __init__(self, func: Callable, args: Iterable[Any]):
def __init__(self, func: Callable, args: Iterable[Any], context: multiprocessing.context.BaseContext):
self.func = func
self.args = args
self.context = context

self.process = None

Expand All @@ -106,7 +107,7 @@ def __init__(self, func: Callable, args: Iterable[Any]):
self.is_paused = False

# prevent race conditions when handling a signal
self.lock = multiprocessing.RLock()
self.lock = self.context.RLock()

def run(self):
with self.lock:
Expand All @@ -117,7 +118,7 @@ def run(self):
with tempfile.NamedTemporaryFile() as result_tmpfile:
result_tmpfile = pathlib.Path(result_tmpfile.name)
with self.lock:
self.process = multiprocessing.Process(
self.process = self.context.Process(
target=functools.partial(
_wrapped_func, self.func, self.args, result_tmpfile
),
Expand Down Expand Up @@ -216,7 +217,9 @@ class ParallelBoostrapProcessManager:
"""

def __init__(self, func: Callable, args: Iterable[Any]):
self.processes = [ProcessWrapper(func, arg) for arg in args]
self.context = multiprocessing.get_context("spawn")
self.processes = [ProcessWrapper(func, arg, self.context) for arg in args]


def run(
self,
Expand Down
9 changes: 5 additions & 4 deletions pandora/embedding_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import concurrent.futures
import itertools
import multiprocessing
import warnings
from typing import List, Optional, Tuple

Expand Down Expand Up @@ -367,7 +368,7 @@ def get_pairwise_stabilities(self, threads: Optional[int] = None) -> pd.Series:
Each value is between 0 and 1 with higher values indicating a higher stability.
"""
with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool:
with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool:
pairwise_stabilities = pool.map(
_stability_for_pair,
itertools.combinations(enumerate(self.embeddings), r=2),
Expand Down Expand Up @@ -428,7 +429,7 @@ def get_pairwise_cluster_stabilities(
enumerate(self.embeddings), r=2
)
]
with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool:
with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool:
pairwise_cluster_stabilities = pool.map(_cluster_stability_for_pair, args)

pairwise_cluster_stabilities = pd.concat(pairwise_cluster_stabilities)
Expand Down Expand Up @@ -463,7 +464,7 @@ def _get_pairwise_difference(
(embedding1, embedding2, sample_ids)
for embedding1, embedding2 in itertools.permutations(self.embeddings, r=2)
]
with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool:
with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool:
diffs = pool.map(_difference_for_pair, args)
return diffs

Expand Down Expand Up @@ -496,7 +497,7 @@ def get_sample_support_values(self, threads: Optional[int] = None) -> pd.Series:
)

args = [(embedding, sample_ids_superset) for embedding in self.embeddings]
with concurrent.futures.ProcessPoolExecutor(max_workers=threads) as pool:
with concurrent.futures.ProcessPoolExecutor(max_workers=threads, mp_context=multiprocessing.get_context("spawn")) as pool:
embedding_norms = pool.map(_get_embedding_norm, args)

denominator = 2 * len(self.embeddings) * np.sum(embedding_norms, axis=0) + 1e-6
Expand Down
2 changes: 0 additions & 2 deletions pandora/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import datetime
import math
import multiprocessing
import pathlib
import sys
import textwrap
Expand Down Expand Up @@ -134,5 +133,4 @@ def main():


if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
main()

0 comments on commit 176c489

Please sign in to comment.