From 17ad3567b8c87575c7ae18c1bf9f771b12665943 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Thu, 26 Mar 2020 16:42:53 +1100
Subject: [PATCH] Fix some minor issues

---
 bilby/core/sampler/ptemcee.py | 29 ++++++++++++++++++-----------
 1 file changed, 18 insertions(+), 11 deletions(-)

diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 8c26bd4ce..578377aa1 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -352,27 +352,32 @@ class Ptemcee(MCMCSampler):
             samples = sampler.chain[0, :, : sampler.time, :]
             taus = []
             for ii in range(sampler.nwalkers):
+                tau_ii = []
                 for jj, key in enumerate(self.search_parameter_keys):
                     if self.ignore_keys_for_tau and self.ignore_keys_for_tau in key:
                         continue
                     try:
-                        taus.append(
+                        tau_ii.append(
                             emcee.autocorr.integrated_time(
                                 samples[ii, :, jj], c=self.autocorr_c, tol=0
                             )[0]
                         )
                     except emcee.autocorr.AutocorrError:
                         taus.append(np.inf)
+                taus.append(tau_ii)
+
+            tau = np.max(np.mean(taus, axis=0))
 
             # Apply multiplicitive safety factor
-            tau = self.safety * np.mean(taus)
+            tau = self.safety * tau
 
             # Store for convergence checking and plotting
-            self.tau_list.append(tau)
+            self.tau_list.append(np.mean(taus, axis=0))
             self.tau_list_n.append(sampler.time)
 
             # Convert to an integer
             tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
+            self.tau_int = tau_int
 
             if np.isnan(tau_int) or np.isinf(tau_int):
                 print_progress(
@@ -465,9 +470,10 @@ class Ptemcee(MCMCSampler):
 
     def write_current_state_and_exit(self, signum=None, frame=None):
         logger.warning("Run terminated with signal {}".format(signum))
-        if getattr(self, "pool", None):
+        if getattr(self, "pool", None) or self.threads == 1:
             self.write_current_state(plot=False)
-            logger.warning("Closing pool")
+        if getattr(self, "pool", None):
+            logger.info("Closing pool")
             self.pool.close()
         sys.exit(self.exit_code)
 
@@ -501,8 +507,10 @@ class Ptemcee(MCMCSampler):
             plot_tau(
                 self.tau_list_n,
                 self.tau_list,
+                self.search_parameter_keys,
                 self.outdir,
                 self.label,
+                self.tau_int,
                 self.autocorr_tau,
             )
 
@@ -644,15 +652,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
     plt.close(fig)
 
 
-def plot_tau(tau_list_n, tau_list, outdir, label, autocorr_tau):
+def plot_tau(tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau):
     fig, ax = plt.subplots()
-    ax.plot(tau_list_n, tau_list, "-", color="C1")
-    check_tau_idx = -int(tau_list[-1] * autocorr_tau)
-    check_taus = tau_list[check_tau_idx:]
-    check_taus_n = tau_list_n[check_tau_idx:]
-    ax.plot(check_taus_n, check_taus, "-", color="C0")
+    for i, key in enumerate(search_parameter_keys):
+        ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
+    ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
     ax.set_xlabel("Iteration")
     ax.set_ylabel(r"$\langle \tau \rangle$")
+    ax.legend()
     fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
     plt.close(fig)
 
-- 
GitLab