Skip to content
Snippets Groups Projects
Commit 17ad3567 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Fix some minor issues

parent 486ad46f
No related branches found
No related tags found
1 merge request!750Improve ptemcee
Pipeline #113497 failed
......@@ -352,27 +352,32 @@ class Ptemcee(MCMCSampler):
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
tau_ii = []
for jj, key in enumerate(self.search_parameter_keys):
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
taus.append(
tau_ii.append(
emcee.autocorr.integrated_time(
samples[ii, :, jj], c=self.autocorr_c, tol=0
)[0]
)
except emcee.autocorr.AutocorrError:
taus.append(np.inf)
taus.append(tau_ii)
tau = np.max(np.mean(taus, axis=0))
# Apply multiplicitive safety factor
tau = self.safety * np.mean(taus)
tau = self.safety * tau
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list.append(np.mean(taus, axis=0))
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
self.tau_int = tau_int
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
......@@ -465,9 +470,10 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if getattr(self, "pool", None):
if getattr(self, "pool", None) or self.threads == 1:
self.write_current_state(plot=False)
logger.warning("Closing pool")
if getattr(self, "pool", None):
logger.info("Closing pool")
self.pool.close()
sys.exit(self.exit_code)
......@@ -501,8 +507,10 @@ class Ptemcee(MCMCSampler):
plot_tau(
self.tau_list_n,
self.tau_list,
self.search_parameter_keys,
self.outdir,
self.label,
self.tau_int,
self.autocorr_tau,
)
......@@ -644,15 +652,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-", color="C1")
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color="C0")
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment