From 17ad3567b8c87575c7ae18c1bf9f771b12665943 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Thu, 26 Mar 2020 16:42:53 +1100 Subject: [PATCH] Fix some minor issues --- bilby/core/sampler/ptemcee.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 8c26bd4ce..578377aa1 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -352,27 +352,32 @@ class Ptemcee(MCMCSampler): samples = sampler.chain[0, :, : sampler.time, :] taus = [] for ii in range(sampler.nwalkers): + tau_ii = [] for jj, key in enumerate(self.search_parameter_keys): if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key: continue try: - taus.append( + tau_ii.append( emcee.autocorr.integrated_time( samples[ii, :, jj], c=self.autocorr_c, tol=0 )[0] ) except emcee.autocorr.AutocorrError: taus.append(np.inf) + taus.append(tau_ii) + + tau = np.max(np.mean(taus, axis=0)) # Apply multiplicitive safety factor - tau = self.safety * np.mean(taus) + tau = self.safety * tau # Store for convergence checking and plotting - self.tau_list.append(tau) + self.tau_list.append(np.mean(taus, axis=0)) self.tau_list_n.append(sampler.time) # Convert to an integer tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau + self.tau_int = tau_int if np.isnan(tau_int) or np.isinf(tau_int): print_progress( @@ -465,9 +470,10 @@ class Ptemcee(MCMCSampler): def write_current_state_and_exit(self, signum=None, frame=None): logger.warning("Run terminated with signal {}".format(signum)) - if getattr(self, "pool", None): + if getattr(self, "pool", None) or self.threads == 1: self.write_current_state(plot=False) - logger.warning("Closing pool") + if getattr(self, "pool", None): + logger.info("Closing pool") self.pool.close() sys.exit(self.exit_code) @@ -501,8 +507,10 @@ class Ptemcee(MCMCSampler): plot_tau( self.tau_list_n, self.tau_list, + self.search_parameter_keys, self.outdir, self.label, + self.tau_int, self.autocorr_tau, ) @@ -644,15 +652,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label): plt.close(fig) -def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau): +def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau): fig, ax = plt.subplots() - ax.plot(tau_list_n, tau_list, "-", color="C1") - check_tau_idx = -int(tau_list[-1] * autocorr_tau) - check_taus = tau_list[check_tau_idx:] - check_taus_n = tau_list_n[check_tau_idx:] - ax.plot(check_taus_n, check_taus, "-", color="C0") + for i, key in enumerate(search_parameter_keys): + ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key) + ax.axvline(tau_list_n[-1] - tau * autocorr_tau) ax.set_xlabel("Iteration") ax.set_ylabel(r"$\langle \tau \rangle$") + ax.legend() fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label)) plt.close(fig) -- GitLab