Skip to content

Commit

Permalink
added multiprocessing.set_start_method to safely determine method
Browse files Browse the repository at this point in the history
  • Loading branch information
IainHammond committed Jan 23, 2025
1 parent 35b0cb9 commit f523377
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions vip_hci/fm/negfc_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
import numpy as np
import os
import emcee
from multiprocessing import cpu_count, Pool, set_start_method
import multiprocessing
import inspect
import datetime
import corner
Expand Down Expand Up @@ -814,8 +814,8 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
else:
raise TypeError("Interpolation not recognized.")

if nproc is None:
nproc = cpu_count() // 2 # Hyper-threading doubles the # of cores
if nproc is None: # actually not used anymore, now that Pool has been implemented for EnsembleSampler
nproc = multiprocessing.cpu_count()

# #########################################################################
# Initialization of the variables
Expand Down Expand Up @@ -914,8 +914,13 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"

set_start_method("forkserver", force=True)
with Pool() as pool:
avail_methods = multiprocessing.get_all_start_methods()
if "forkserver" in avail_methods:
multiprocessing.set_start_method("forkserver", force=True) # faster, better
else:
multiprocessing.set_start_method("spawn", force=True) # slower, but on all platforms

with multiprocessing.Pool() as pool:
sampler = emcee.EnsembleSampler(nwalkers, dim, lnprob,
pool=pool, moves=emcee.moves.StretchMove(a=2),
args=([bounds, cube, angs, psfn,
Expand Down Expand Up @@ -1078,7 +1083,7 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
timing(start_time)

# reactivate multithreading
ncpus = cpu_count()
ncpus = multiprocessing.cpu_count()
os.environ["MKL_NUM_THREADS"] = str(ncpus)
os.environ["NUMEXPR_NUM_THREADS"] = str(ncpus)
os.environ["OMP_NUM_THREADS"] = str(ncpus)
Expand Down

0 comments on commit f523377

Please sign in to comment.