diff --git a/.github/workflows/test-pandora.yml b/.github/workflows/test-pandora.yml index f303dd1..29c5c1e 100644 --- a/.github/workflows/test-pandora.yml +++ b/.github/workflows/test-pandora.yml @@ -33,7 +33,7 @@ jobs: # uses: mxschmitt/action-tmate@v3 - name: Run Pandora tests run: | - PYTHONPATH=. pytest -v --color=yes + PYTHONPATH=. pytest -v -x --color=yes -k 'mds and test_bootstrap_and_embed_multiple_numpy' shell: micromamba-shell {0} Install-using-pip: diff --git a/pandora/bootstrap.py b/pandora/bootstrap.py index 138a6f7..9f52ec8 100644 --- a/pandora/bootstrap.py +++ b/pandora/bootstrap.py @@ -10,8 +10,7 @@ import signal import tempfile import time -from multiprocessing import Event, Process -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import loguru import numpy as np @@ -26,82 +25,6 @@ from pandora.embedding_comparison import BatchEmbeddingComparison from pandora.logger import fmt_message -# Event to terminate waiting and running bootstrap processes in case convergence was detected -STOP_BOOTSTRAP = Event() -# Event to pause waiting and running bootstrap processes during convergence check -# Using this event allows us to utilize all cores for the convergence check -PAUSE_BOOTSTRAP = Event() - - -def _wrapped_func(func, args, tmpfile): - """Runs the given function ``func`` with the given arguments ``args`` and dumps the results as pickle in the - tmpfile. - - If the function call raises an exception, the Exception is also dumped in the - tmpfile. - """ - try: - result = func(*args) - except Exception as e: - result = e - tmpfile.write_bytes(pickle.dumps(result)) - - -def _run_function_in_process(func, args): - """Runs the given function ``func`` with the provided arguments ``args`` in a multiprocessing.Process. - - We periodically check for a stop signal (`STOP_BOOTSTRAP`) that signals bootstrap convergence. - If this signal is set, we terminate the running process. - Returns the result of ``func(*args)`` in case the process terminates without an error. - If the underlying ``func`` call raises an exception, the exception is passed through to the caller - of this function. - """ - if STOP_BOOTSTRAP.is_set(): - # instead of starting and immediately stopping the process stop here - return - # since multiprocessing.Process provides no interface to the computed result, - # we need a Queue to store the bootstrap results in - # -> we pass this Queue as additional argument to the _wrapped_func - with tempfile.NamedTemporaryFile() as tmpfile: - tmpfile = pathlib.Path(tmpfile.name) - # open a new Process using _wrapped_func - # _wrapped_func simply calls the specified function ``func`` with the provided arguments ``args`` - # and stores the result (or Exception) as pickle dump in tmpfile - # we can't use a multiprocessing.Queue here since the resulting dataset objects are too large for the Queue - process = Process( - target=functools.partial(_wrapped_func, func, args, tmpfile), - daemon=True, - ) - process.start() - process_paused = False - - while 1: - if not process.is_alive() or STOP_BOOTSTRAP.is_set(): - # Bootstrap is done computing or Bootstrapping convergence detected, terminate the running Process - process.terminate() - break - if PAUSE_BOOTSTRAP.is_set(): - if not process_paused: - # Bootrap convergence check running requiring all provided resources, pause process - os.kill(process.pid, signal.SIGTSTP) - process_paused = True - time.sleep(0.01) - continue - if process_paused: - os.kill(process.pid, signal.SIGCONT) - process_paused = False - time.sleep(0.01) - process.join() - process.close() - if not STOP_BOOTSTRAP.is_set(): - # we can only get the result from the result_queue if the process was not terminated due to the stop signal - result = pickle.load(tmpfile.open("rb")) - if isinstance(result, Exception): - # the result_queue also stores the Exception if the underlying func call raises one - # in this case we simply re-raise the Exception to be able to properly handle it - raise result - return result - def _bootstrap_convergence_check( bootstraps: List[Union[NumpyDataset, EigenDataset]], @@ -119,18 +42,7 @@ def _bootstrap_convergence_check( raise PandoraException( f"Unrecognized embedding option {embedding}. Supported are 'pca' and 'mds'." ) - # interrupt other running bootstrap processes - PAUSE_BOOTSTRAP.set() - time.sleep(0.1) - if logger is not None: - logger.debug(fmt_message("Pausing other bootstrap processes.")) - converged = _bootstrap_converged(embeddings, threads) - # resume remaining bootstrap processes - PAUSE_BOOTSTRAP.clear() - time.sleep(0.1) - if logger is not None: - logger.debug(fmt_message("Resuming other bootstrap processes.")) - return converged + return _bootstrap_converged(embeddings, threads) def _bootstrap_converged(bootstraps: List[Embedding], threads: int): @@ -153,6 +65,201 @@ def _bootstrap_converged(bootstraps: List[Embedding], threads: int): return np.all(pairwise_relative_differences <= 0.05) +def _wrapped_func(func, args, tmpfile): + """Runs the given function ``func`` with the given arguments ``args`` and dumps the results as pickle in the + tmpfile. + + If the function call raises an exception, the Exception is also dumped in the + tmpfile. + """ + try: + result = func(*args) + except Exception as e: + result = e + tmpfile.write_bytes(pickle.dumps(result)) + + +class ProcessWrapper: + """ + TODO: Docstring + """ + + def __init__(self, func: Callable, args: Iterable[Any]): + self.func = func + self.args = args + + self.process = None + + # flags used to send signals to the running process + self.pause_execution = False + self.terminate_execution = False + self.is_paused = False + + # prevent race conditions when handling a signal + self.lock = multiprocessing.RLock() + + def run(self): + with self.lock: + if self.terminate_execution: + # received terminate signal before starting, nothing to do + return + + with tempfile.NamedTemporaryFile() as result_tmpfile: + result_tmpfile = pathlib.Path(result_tmpfile.name) + self.process = multiprocessing.Process( + target=functools.partial( + _wrapped_func, self.func, self.args, result_tmpfile + ), + daemon=True, + ) + self.process.start() + + while 1: + with self.lock: + process_complete = not self.process.is_alive() + if self.terminate_execution or process_complete: + # Process finished or termination signal sent from outside + break + elif self.pause_execution: + self._pause() + else: + self._resume() + time.sleep(0.01) + + # Terminate process and get result + self._terminate() + if process_complete: + # Only if the process was not externally terminated, can get and return the result + result = pickle.load(result_tmpfile.open("rb")) + if isinstance(result, Exception): + # if the underlying func call raises an Exception, it is also pickled in the result file + # in this case we simply re-raise the Exception to be able to properly handle it in the caller + raise result + else: + return result + + def terminate(self): + with self.lock: + self.terminate_execution = True + + def pause(self): + with self.lock: + self.pause_execution = True + + def resume(self): + with self.lock: + self.pause_execution = False + + def _pause(self): + with self.lock: + if ( + self.process is not None + and not self.terminate_execution + and not self.is_paused + ): + os.kill(self.process.pid, signal.SIGSTOP) + self.is_paused = True + + def _resume(self): + with self.lock: + if ( + self.process is not None + and not self.terminate_execution + and self.is_paused + ): + os.kill(self.process.pid, signal.SIGCONT) + self.is_paused = False + + def _terminate(self): + with self.lock: + if self.process is not None and self.process.is_alive(): + self.process.terminate() + self.process.join() + self.process.close() + self.process = None + + +class ParallelBoostrapProcessManager: + """ + TODO: Docstring + """ + + def __init__(self, func: Callable, args: Iterable[Any]): + self.processes = [ProcessWrapper(func, arg) for arg in args] + + def run( + self, + threads: int, + bootstrap_convergence_check: bool, + embedding: EmbeddingAlgorithm, + logger: Optional[loguru.Logger] = None, + ): + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: + tasks = [pool.submit(process.run) for process in self.processes] + + bootstraps = [] + finished_indices = [] + + for finished_task in concurrent.futures.as_completed(tasks): + try: + bootstrap, bootstrap_index = finished_task.result() + except Exception as e: + # terminate all running and waiting processes + self.terminate() + # cleanup the ThreadPool + pool.shutdown() + raise PandoraException( + "Something went wrong during the bootstrap computation." + ) from e + + # collect the finished bootstrap dataset + bootstraps.append(bootstrap) + # and also collect the index of the finished bootstrap: + # we need to keep track of this for cleanup in the caller + finished_indices.append(bootstrap_index) + + # perform a convergence check: + # If the user wants to do convergence check (`bootstrap_convergence_check`) + # AND if not all bootstraps already computed anyway + # AND only every max(10, threads) replicates + if ( + bootstrap_convergence_check + and len(bootstraps) < len(self.processes) + and len(bootstraps) % max(10, threads) == 0 + ): + # Pause all running/waiting processes for the convergence check + # so we can use all available threads for the parallel convergence check computation + self.pause() + converged = _bootstrap_convergence_check( + bootstraps, embedding, threads, logger + ) + self.resume() + if converged: + # in case convergence is detected, we set the event that interrupts all running processes + if logger is not None: + logger.debug( + fmt_message( + "Bootstrap convergence detected. Stopping bootstrapping." + ) + ) + self.terminate() + break + + return bootstraps, finished_indices + + def pause(self): + for process in self.processes: + process.pause() + + def resume(self): + for process in self.processes: + process.resume() + + def terminate(self): + for process in self.processes: + process.terminate() + + def _bootstrap_and_embed( bootstrap_index, dataset, @@ -285,10 +392,6 @@ def bootstrap_and_embed_multiple( the 48 replicates are already computed anyway, might as well use them instead of throwing away 30 in case 10 would have been sufficient). """ - # before starting the bootstrap computation, make sure the convergence and pause signals are cleared - STOP_BOOTSTRAP.clear() - PAUSE_BOOTSTRAP.clear() - if threads is None: threads = multiprocessing.cpu_count() @@ -326,61 +429,15 @@ def bootstrap_and_embed_multiple( ) ) - with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: - tasks = [ - pool.submit(_run_function_in_process, _bootstrap_and_embed, args) - for args in bootstrap_args - ] - bootstraps = [] - finished_indices = [] - - if logger is not None: - logger.debug(fmt_message("Starting Bootstrap computation.")) - - for finished_task in concurrent.futures.as_completed(tasks): - try: - bootstrap_dataset, bootstrap_index = finished_task.result() - if logger is not None: - logger.debug( - fmt_message(f"Finished computing bootstrap #{bootstrap_index}.") - ) - except Exception as e: - STOP_BOOTSTRAP.set() - time.sleep(0.1) - pool.shutdown() - raise PandoraException( - "Something went wrong during the bootstrap computation." - ) from e - - # collect the finished bootstrap dataset - bootstraps.append(bootstrap_dataset) - # and also collect the index of the finished bootstrap: - # we need to keep track of this for file cleanup later on - finished_indices.append(bootstrap_index) - - # perform a convergence check: - # If the user wants to do convergence check (`bootstrap_convergence_check`) - # AND if not all n_bootstraps already computed anyway - # AND only every max(10, threads) replicates - if ( - bootstrap_convergence_check - and len(bootstraps) < n_bootstraps - and len(bootstraps) % max(10, threads) == 0 - ): - if _bootstrap_convergence_check(bootstraps, embedding, threads, logger): - # in case convergence is detected, we set the event that interrupts all running smartpca runs - if logger is not None: - logger.debug( - fmt_message( - "Bootstrap convergence detected. Stopping bootstrapping." - ) - ) - STOP_BOOTSTRAP.set() - break - - # reset the convergence and pause flag - STOP_BOOTSTRAP.clear() - PAUSE_BOOTSTRAP.clear() + parallel_bootstrap_process_manager = ParallelBoostrapProcessManager( + _bootstrap_and_embed, bootstrap_args + ) + bootstraps, finished_indices = parallel_bootstrap_process_manager.run( + threads=threads, + bootstrap_convergence_check=bootstrap_convergence_check, + embedding=embedding, + logger=logger, + ) # we also need to remove all files associated with the interrupted bootstrap calls for bootstrap_index in range(n_bootstraps): @@ -413,7 +470,7 @@ def _bootstrap_and_embed_numpy( f"Unrecognized embedding option {embedding}. Supported are 'pca' and 'mds'." ) - return bootstrap + return bootstrap, seed def bootstrap_and_embed_multiple_numpy( @@ -498,10 +555,6 @@ def bootstrap_and_embed_multiple_numpy( the 48 replicates are already computed anyway, might as well use them instead of throwing away 30 in case 10 would have been sufficient). """ - # before starting the bootstrap computation, make sure the convergence and pause signals are cleared - STOP_BOOTSTRAP.clear() - PAUSE_BOOTSTRAP.clear() - if threads is None: threads = multiprocessing.cpu_count() @@ -519,40 +572,14 @@ def bootstrap_and_embed_multiple_numpy( ) for _ in range(n_bootstraps) ] - with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as pool: - tasks = [ - pool.submit(_run_function_in_process, _bootstrap_and_embed_numpy, args) - for args in bootstrap_args - ] - bootstraps = [] - - for finished_task in concurrent.futures.as_completed(tasks): - try: - bootstrap = finished_task.result() - bootstraps.append(bootstrap) - except Exception as e: - STOP_BOOTSTRAP.set() - time.sleep(0.1) - pool.shutdown() - raise PandoraException( - "Something went wrong during the bootstrap computation." - ) from e - - # perform a convergence check: - # If the user wants to do convergence check (`bootstrap_convergence_check`) - # AND if not all n_bootstraps already computed anyway - # AND only every max(10, threads) replicates - if ( - bootstrap_convergence_check - and len(bootstraps) < n_bootstraps - and len(bootstraps) % max(10, threads) == 0 - ): - if _bootstrap_convergence_check(bootstraps, embedding, threads): - # in case convergence is detected, we set the event that interrupts all running bootstrap runs - STOP_BOOTSTRAP.set() - break - - # reset the convergence and pause flag - STOP_BOOTSTRAP.clear() - PAUSE_BOOTSTRAP.clear() + + parallel_bootstrap_process_manager = ParallelBoostrapProcessManager( + _bootstrap_and_embed_numpy, bootstrap_args + ) + bootstraps, _ = parallel_bootstrap_process_manager.run( + threads=threads, + bootstrap_convergence_check=bootstrap_convergence_check, + embedding=embedding, + ) + return bootstraps diff --git a/pandora/dataset.py b/pandora/dataset.py index a782c54..4fd1227 100644 --- a/pandora/dataset.py +++ b/pandora/dataset.py @@ -788,9 +788,6 @@ def bootstrap( EigenDataset A new dataset object containing the bootstrap replicate data. """ - from pandora.logger import logger - - logger.debug(f"computing bootstrap dataset: {bootstrap_prefix}") bs_ind_file = pathlib.Path(f"{bootstrap_prefix}.ind") bs_geno_file = pathlib.Path(f"{bootstrap_prefix}.geno") bs_snp_file = pathlib.Path(f"{bootstrap_prefix}.snp") @@ -800,7 +797,6 @@ def bootstrap( ) if files_exist and not redo: - logger.debug(f"bootstrap already exists {bootstrap_prefix}") return EigenDataset( bootstrap_prefix, self._embedding_populations_file, @@ -863,7 +859,6 @@ def bootstrap( # when bootstrapping on SNP level, the .ind file does not change shutil.copy(self._ind_file, bs_ind_file) - logger.debug(f"bootstrap computation done: {bootstrap_prefix}") return EigenDataset( bootstrap_prefix, self._embedding_populations_file, diff --git a/tests/conftest.py b/tests/conftest.py index 2f7ec12..2557e60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ import pandas as pd import pytest -from pandora.bootstrap import PAUSE_BOOTSTRAP, STOP_BOOTSTRAP from pandora.custom_types import Executable from pandora.dataset import EigenDataset, NumpyDataset from pandora.embedding import PCA, from_smartpca @@ -73,10 +72,3 @@ def cleanup_pandora_test_results(): results_dir = pathlib.Path("tests") / "data" / "results" if results_dir.exists(): shutil.rmtree(results_dir) - - -@pytest.fixture(autouse=True) -def cleanup_bootstrap_signals(): - # reset the bootstrap stop and pause signal after each test to make sure tests don't influence each other - STOP_BOOTSTRAP.clear() - PAUSE_BOOTSTRAP.clear() diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index 4fb999a..90f3ce5 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -9,10 +9,10 @@ import pytest from pandora.bootstrap import ( - STOP_BOOTSTRAP, + ParallelBoostrapProcessManager, + ProcessWrapper, _bootstrap_converged, _bootstrap_convergence_check, - _run_function_in_process, _wrapped_func, bootstrap_and_embed_multiple, bootstrap_and_embed_multiple_numpy, @@ -29,17 +29,152 @@ def _dummy_func(status: int): raise ValueError("Status < 0") -def _dummy_func_with_wait(status: int): - """Dummy function that returns the status if it is > 0 and otherwise raises a ValueError. +def _dummy_func_with_wait(wait_seconds: int): + """Dummy function that returns the wait_seconds if it is > 0 and otherwise raises a ValueError. - Compared to _dummy_func it however waits for 1s before returning/raising. + Compared to _dummy_func it however sleeps for ``wait_seconds`` before returning/raising. """ - time.sleep(1) - if status >= 0: - return status + # since sleep is based on the CPU time, we can only test the process.pause() function if we do this hack + total_wait = 0 + while total_wait < wait_seconds: + # print(total_wait) + time.sleep(0.001) + total_wait += 0.001 + if wait_seconds >= 0: + return wait_seconds raise ValueError("Status < 0") +class TestProcessWrapper: + def test_run(self): + wait_duration = 1 + proc = ProcessWrapper(_dummy_func_with_wait, [wait_duration]) + # without any external signals, this should simply run and return the set wait_duration + result = proc.run() + assert result == wait_duration + + def test_run_with_terminate_signal(self): + # setting the terminate signal of the process should prevent the process from starting and return None + proc = ProcessWrapper(_dummy_func_with_wait, [1]) + proc.terminate_execution = True + res = proc.run() + assert res is None + + def test_run_with_terminate_signal_set_during_execution(self): + # we create five processes using concurrent futures but set the stop signal once three are completed + processes = [ProcessWrapper(_dummy_func_with_wait, [i]) for i in range(50)] + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + tasks = [pool.submit(process.run) for process in processes] + finished_ct = 0 + cancelled_ct = 0 + + for finished_task in concurrent.futures.as_completed(tasks): + result = finished_task.result() + if result is None: + cancelled_ct += 1 + continue + finished_ct += 1 + assert ( + result >= 0 + ) # every result should be an int >= 0, no error should be raised + + if finished_ct >= 3: + # send the stop signal to all remaining processes + for process in processes: + process.terminate() + + # we should only have three results, but during the termination some other results might finish + # so we allow for some margin + assert finished_ct < 10 + assert cancelled_ct >= 40 + # also all processes should be None + assert all(p.process is None for p in processes) + + def test_run_with_exception_during_execution(self): + # we create five processes using concurrent futures but one of them raises a ValueError + # we catch this error and transform it to a Pandora exception to make sure catching the error works as expected + with pytest.raises(PandoraException): + processes = [ + ProcessWrapper(_dummy_func_with_wait, [i]) for i in range(-1, 4) + ] + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + tasks = [pool.submit(process.run) for process in processes] + for finished_task in concurrent.futures.as_completed(tasks): + try: + finished_task.result() + except ValueError: + # we specifically catch only the ValueError raised by the _dummy_func_with_wait + # there should be no other errors + raise PandoraException() + + def test_pause_and_resume(self): + """Since _dummy_func_with_wait only waits for the passed duration of seconds X and then returns, it is + reasonable to assume that it's runtime will be about X seconds in total. + + If we pause the process during this wait time for 3 seconds, the total runtime + of the function call should be >= 3 + X seconds. + """ + process = ProcessWrapper(_dummy_func_with_wait, [2]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + start_time = time.perf_counter() + future = pool.submit(process.run) + process.pause() + time.sleep(3) + process.resume() + # wait for the process to finish + future.result() + end_time = time.perf_counter() + assert end_time - start_time >= 5 + + def test_pause_and_terminate(self): + # we should be able to terminate a paused process without any error + process = ProcessWrapper(_dummy_func_with_wait, [2]) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + future = pool.submit(process.run) + process.pause() + process.terminate() + result = future.result() + assert result is None + assert process.process is None + + +class TestParallelBoostrapProcessManager: + def test_terminate(self): + # Note: this test is similar to one above, but in this case we explicitly test the termination implementation + # of the ParallelBoostrapProcessManager + process_manager = ParallelBoostrapProcessManager( + _dummy_func_with_wait, [[i] for i in range(50)] + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + tasks = [pool.submit(process.run) for process in process_manager.processes] + + finished_ct = 0 + cancelled_ct = 0 + for finished_task in concurrent.futures.as_completed(tasks): + # once 3 tasks are done, send the pause signal + result = finished_task.result() + if result is None: + cancelled_ct += 1 + continue + finished_ct += 1 + + if finished_ct == 3: + # terminate bootstraps + process_manager.terminate() + + # we should only have three results, but during the termination some other results might finish + # so we allow for some margin + assert finished_ct < 10 + assert cancelled_ct >= 40 + # also all processes should be None + assert all(p.process is None for p in process_manager.processes) + + def test_wrapped_func(): with tempfile.NamedTemporaryFile() as tmpfile: tmpfile = pathlib.Path(tmpfile.name) @@ -61,67 +196,6 @@ def test_wrapped_func(): assert str(status) == "Status < 0" -def test_run_function_in_progress_stop_signal_set(): - # setting the stop signal before calling _run_function_in_progress should return None - STOP_BOOTSTRAP.set() - status = _run_function_in_process(_dummy_func, [1]) - assert status is None - - -def test_run_function_in_progress_stop_signal_unset(): - # if the stop signal is not set, calling _run_function_in_progress with _dummy_func and [1] should return 1 - STOP_BOOTSTRAP.clear() - status = _run_function_in_process(_dummy_func, [1]) - assert status == 1 - - -def test_run_function_in_progress_stop_signal_set_during_execution(): - # we create five processes using concurrent futures but set the stop signal once three are completed - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - tasks = [ - pool.submit(_run_function_in_process, _dummy_func_with_wait, [status]) - for status in range(50) - ] - finished_ct = 0 - cancelled_results = 0 - - for finished_task in concurrent.futures.as_completed(tasks): - result = finished_task.result() - if result is None: - cancelled_results += 1 - continue - finished_ct += 1 - assert ( - result >= 0 - ) # every result should be an int >= 0, no error should be raised - - if finished_ct >= 3: - # send the stop signal to all remaining processes - STOP_BOOTSTRAP.set() - - # we should only have three results, but - assert finished_ct < 10 - assert cancelled_results >= 40 - - -def test_run_function_in_progress_exception_during_execution(): - # we create five processes using concurrent futures but one of them raises a ValueError - # we catch this error and transform it to a Pandora exception to make sure catching the error works as expected - with pytest.raises(PandoraException): - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - tasks = [ - pool.submit(_run_function_in_process, _dummy_func_with_wait, [status]) - for status in range(-1, 4) - ] - for finished_task in concurrent.futures.as_completed(tasks): - try: - finished_task.result() - except ValueError: - # we specifically catch only the ValueError raised by the _dummy_func_with_wait - # there should be no other errors - raise PandoraException() - - @pytest.mark.parametrize( "embedding_algorithm", [EmbeddingAlgorithm.PCA, EmbeddingAlgorithm.MDS] )