Skip to content

Commit

Permalink
testing new pool method for mcmc
Browse files Browse the repository at this point in the history
  • Loading branch information
IainHammond committed Jan 22, 2025
1 parent 1362178 commit 238f88c
Showing 1 changed file with 117 additions and 117 deletions.
234 changes: 117 additions & 117 deletions vip_hci/fm/negfc_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,130 +926,130 @@ def mcmc_negfc_sampling(cube, angs, psfn, initial_state, algo=pca_annulus,
algo_options, weights, transmission,
mu_sigma, sigma, force_rPA]))

if verbosity > 0:
print('emcee Ensemble sampler successful')
start = datetime.datetime.now()

# #########################################################################
# Affine Invariant MCMC run
# #########################################################################
if verbosity > 1:
print('\nStart of the MCMC run ...')
print('Step | Duration/step (sec) | Remaining Estimated Time (sec)')
if verbosity > 0:
print('emcee Ensemble sampler successful')
start = datetime.datetime.now()

for k, res in enumerate(sampler.sample(pos, iterations=nIterations)):
elapsed = (datetime.datetime.now()-start).total_seconds()
# #########################################################################
# Affine Invariant MCMC run
# #########################################################################
if verbosity > 1:
if k == 0:
q = 0.5
else:
q = 1
print('{}\t\t{:.5f}\t\t\t{:.5f}'.format(k, elapsed * q,
elapsed * (limit-k-1) * q),
flush=True)

start = datetime.datetime.now()
print('\nStart of the MCMC run ...')
print('Step | Duration/step (sec) | Remaining Estimated Time (sec)')

# ---------------------------------------------------------------------
# Store the state manually in order to handle with dynamical sized chain
# ---------------------------------------------------------------------
# Check if the size of the chain is long enough.
s = chain.shape[1]
if k+1 > s: # if not, one doubles the chain length
empty = np.zeros([nwalkers, 2*s, dim])
chain = np.concatenate((chain, empty), axis=1)
# Store the state of the chain
chain[:, k] = res[0]

# ---------------------------------------------------------------------
# If k meets the criterion, one tests the non-convergence.
# ---------------------------------------------------------------------
criterion = int(np.amin([np.ceil(itermin*(1+fraction)**geom),
lastcheck+np.floor(maxgap)]))
if k == criterion:
for k, res in enumerate(sampler.sample(pos, iterations=nIterations)):
elapsed = (datetime.datetime.now()-start).total_seconds()
if verbosity > 1:
print('\n {} convergence test in progress...'.format(conv_test))

geom += 1
lastcheck = k
if display:
show_walk_plot(chain, labels=labels)

if save and verbosity == 3:
fname = '{d}/{f}_temp_k{k}'.format(d=output_dir,
f=output_file_tmp, k=k)
data = {'chain': sampler.chain,
'lnprob': sampler.lnprobability,
'AR': sampler.acceptance_fraction}
with open(fname, 'wb') as fileSave:
pickle.dump(data, fileSave)

# We only test the rhat if we have reached the min # of steps
if (k+1) >= itermin and konvergence == np.inf:
if conv_test == 'gb':
thr0 = int(np.floor(burnin*k))
thr1 = int(np.floor((1-burnin)*k*0.25))

# We calculate the rhat for each model parameter.
for j in range(dim):
part1 = chain[:, thr0:thr0 + thr1, j].reshape(-1)
part2 = chain[:, thr0 + 3 * thr1:thr0 + 4 * thr1, j
].reshape(-1)
series = np.vstack((part1, part2))
rhat[j] = gelman_rubin(series)
if verbosity > 0:
print(' r_hat = {}'.format(rhat))
cond = rhat <= rhat_threshold
print(' r_hat <= threshold = {} \n'.format(cond), flush=True)
# We test the rhat.
if (rhat <= rhat_threshold).all():
rhat_count += 1
if rhat_count < rhat_count_threshold:
if verbosity > 0:
msg = "Gelman-Rubin test OK {}/{}"
print(msg.format(rhat_count,
rhat_count_threshold))
elif rhat_count >= rhat_count_threshold:
if k == 0:
q = 0.5
else:
q = 1
print('{}\t\t{:.5f}\t\t\t{:.5f}'.format(k, elapsed * q,
elapsed * (limit-k-1) * q),
flush=True)

start = datetime.datetime.now()

# ---------------------------------------------------------------------
# Store the state manually in order to handle with dynamical sized chain
# ---------------------------------------------------------------------
# Check if the size of the chain is long enough.
s = chain.shape[1]
if k+1 > s: # if not, one doubles the chain length
empty = np.zeros([nwalkers, 2*s, dim])
chain = np.concatenate((chain, empty), axis=1)
# Store the state of the chain
chain[:, k] = res[0]

# ---------------------------------------------------------------------
# If k meets the criterion, one tests the non-convergence.
# ---------------------------------------------------------------------
criterion = int(np.amin([np.ceil(itermin*(1+fraction)**geom),
lastcheck+np.floor(maxgap)]))
if k == criterion:
if verbosity > 1:
print('\n {} convergence test in progress...'.format(conv_test))

geom += 1
lastcheck = k
if display:
show_walk_plot(chain, labels=labels)

if save and verbosity == 3:
fname = '{d}/{f}_temp_k{k}'.format(d=output_dir,
f=output_file_tmp, k=k)
data = {'chain': sampler.chain,
'lnprob': sampler.lnprobability,
'AR': sampler.acceptance_fraction}
with open(fname, 'wb') as fileSave:
pickle.dump(data, fileSave)

# We only test the rhat if we have reached the min # of steps
if (k+1) >= itermin and konvergence == np.inf:
if conv_test == 'gb':
thr0 = int(np.floor(burnin*k))
thr1 = int(np.floor((1-burnin)*k*0.25))

# We calculate the rhat for each model parameter.
for j in range(dim):
part1 = chain[:, thr0:thr0 + thr1, j].reshape(-1)
part2 = chain[:, thr0 + 3 * thr1:thr0 + 4 * thr1, j
].reshape(-1)
series = np.vstack((part1, part2))
rhat[j] = gelman_rubin(series)
if verbosity > 0:
print(' r_hat = {}'.format(rhat))
cond = rhat <= rhat_threshold
print(' r_hat <= threshold = {} \n'.format(cond), flush=True)
# We test the rhat.
if (rhat <= rhat_threshold).all():
rhat_count += 1
if rhat_count < rhat_count_threshold:
if verbosity > 0:
msg = "Gelman-Rubin test OK {}/{}"
print(msg.format(rhat_count,
rhat_count_threshold))
elif rhat_count >= rhat_count_threshold:
if verbosity > 0:
print('... ==> convergence reached')
konvergence = k
stop = konvergence + supp
else:
rhat_count = 0
elif conv_test == 'ac':
# We calculate the auto-corr test for each model parameter.
if save:
chain_name = "TMP_test_chain{:.0f}.fits".format(k)
write_fits(output_dir+'/'+chain_name, chain[:, :k])
for j in range(dim):
rhat[j] = autocorr_test(chain[:, :k, j])
thr = 1./ac_c
if verbosity > 0:
print('Auto-corr tau/N = {}'.format(rhat))
print('tau/N <= {} = {} \n'.format(thr, rhat < thr), flush=True)
if (rhat <= thr).all():
ac_count += 1
if verbosity > 0:
print('... ==> convergence reached')
konvergence = k
stop = konvergence + supp
msg = "Auto-correlation test passed for all params!"
msg += "{}/{}".format(ac_count, ac_count_thr)
print(msg)
if ac_count >= ac_count_thr:
msg = '\n ... ==> convergence reached'
print(msg)
stop = k
else:
ac_count = 0
else:
rhat_count = 0
elif conv_test == 'ac':
# We calculate the auto-corr test for each model parameter.
raise ValueError('conv_test value not recognized')
# append the autocorrelation factor to file for easy reading
if save:
chain_name = "TMP_test_chain{:.0f}.fits".format(k)
write_fits(output_dir+'/'+chain_name, chain[:, :k])
for j in range(dim):
rhat[j] = autocorr_test(chain[:, :k, j])
thr = 1./ac_c
if verbosity > 0:
print('Auto-corr tau/N = {}'.format(rhat))
print('tau/N <= {} = {} \n'.format(thr, rhat < thr), flush=True)
if (rhat <= thr).all():
ac_count += 1
if verbosity > 0:
msg = "Auto-correlation test passed for all params!"
msg += "{}/{}".format(ac_count, ac_count_thr)
print(msg)
if ac_count >= ac_count_thr:
msg = '\n ... ==> convergence reached'
print(msg)
stop = k
else:
ac_count = 0
else:
raise ValueError('conv_test value not recognized')
# append the autocorrelation factor to file for easy reading
if save:
with open(output_dir + '/MCMC_results_tau.txt', 'a') as f:
f.write(str(rhat) + '\n')
# We have reached the maximum number of steps for our Markov chain.
if k+1 >= stop:
if verbosity > 0:
print('We break the loop because we have reached convergence')
break
with open(output_dir + '/MCMC_results_tau.txt', 'a') as f:
f.write(str(rhat) + '\n')
# We have reached the maximum number of steps for our Markov chain.
if k+1 >= stop:
if verbosity > 0:
print('We break the loop because we have reached convergence')
break

if k == nIterations-1:
if verbosity > 0:
Expand Down

0 comments on commit 238f88c

Please sign in to comment.