Skip to content
Snippets Groups Projects

Improve ptemcee

Merged Gregory Ashton requested to merge improve-ptemcee into master
All threads resolved!
1 file
+ 18
11
Compare changes
  • Side-by-side
  • Inline
@@ -352,27 +352,32 @@ class Ptemcee(MCMCSampler):
samples = sampler.chain[0, :, : sampler.time, :]
taus = []
for ii in range(sampler.nwalkers):
tau_ii = []
for jj, key in enumerate(self.search_parameter_keys):
if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
continue
try:
taus.append(
tau_ii.append(
emcee.autocorr.integrated_time(
samples[ii, :, jj], c=self.autocorr_c, tol=0
)[0]
)
except emcee.autocorr.AutocorrError:
taus.append(np.inf)
taus.append(tau_ii)
tau = np.max(np.mean(taus, axis=0))
# Apply multiplicitive safety factor
tau = self.safety * np.mean(taus)
tau = self.safety * tau
# Store for convergence checking and plotting
self.tau_list.append(tau)
self.tau_list.append(np.mean(taus, axis=0))
self.tau_list_n.append(sampler.time)
# Convert to an integer
tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
self.tau_int = tau_int
if np.isnan(tau_int) or np.isinf(tau_int):
print_progress(
@@ -465,9 +470,10 @@ class Ptemcee(MCMCSampler):
def write_current_state_and_exit(self, signum=None, frame=None):
logger.warning("Run terminated with signal {}".format(signum))
if getattr(self, "pool", None):
if getattr(self, "pool", None) or self.threads == 1:
self.write_current_state(plot=False)
logger.warning("Closing pool")
if getattr(self, "pool", None):
logger.info("Closing pool")
self.pool.close()
sys.exit(self.exit_code)
@@ -501,8 +507,10 @@ class Ptemcee(MCMCSampler):
plot_tau(
self.tau_list_n,
self.tau_list,
self.search_parameter_keys,
self.outdir,
self.label,
self.tau_int,
self.autocorr_tau,
)
@@ -644,15 +652,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
plt.close(fig)
def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau):
fig, ax = plt.subplots()
ax.plot(tau_list_n, tau_list, "-", color="C1")
check_tau_idx = -int(tau_list[-1] * autocorr_tau)
check_taus = tau_list[check_tau_idx:]
check_taus_n = tau_list_n[check_tau_idx:]
ax.plot(check_taus_n, check_taus, "-", color="C0")
for i, key in enumerate(search_parameter_keys):
ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
ax.set_xlabel("Iteration")
ax.set_ylabel(r"$\langle \tau \rangle$")
ax.legend()
fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
plt.close(fig)
Loading