From 5629325dcfc4514820f1151bd1bab29e1bb15acc Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 21 Dec 2020 04:21:21 -0600
Subject: [PATCH] Add a mean-log-likelihood method to improve the ACT
 estimation

---
 bilby/core/sampler/ptemcee.py     | 501 +++++++++++++++++++++++-------
 test/core/sampler/ptemcee_test.py |  14 +-
 2 files changed, 394 insertions(+), 121 deletions(-)

diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 1ee756c40..4dc97a4aa 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -8,10 +8,12 @@ import sys
 import time
 import dill
 from collections import namedtuple
+import logging
 
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
+import scipy.signal
 
 from ..utils import logger, check_directory_exists_and_if_not_mkdir
 from .base_sampler import SamplerError, MCMCSampler
@@ -23,10 +25,14 @@ ConvergenceInputs = namedtuple(
         "autocorr_c",
         "autocorr_tol",
         "autocorr_tau",
+        "gradient_tau",
+        "gradient_mean_log_posterior",
+        "Q_tol",
         "safety",
         "burn_in_nact",
+        "burn_in_fixed_discard",
+        "mean_logl_frac",
         "thin_by_nact",
-        "frac_threshold",
         "nsamples",
         "ignore_keys_for_tau",
         "min_tau",
@@ -53,6 +59,10 @@ class Ptemcee(MCMCSampler):
         The number of burn-in autocorrelation times to discard and the thin-by
         factor. Increasing burn_in_nact increases the time required for burn-in.
         Increasing thin_by_nact increases the time required to obtain nsamples.
+    burn_in_fixed_discard: int (0)
+        A fixed number of samples to discard for burn-in
+    mean_logl_frac: float, (0.01)
+        The maximum fractional change in the mean log-likelihood to accept
     autocorr_tol: int, (50)
         The minimum number of autocorrelation times needed to trust the
         estimate of the autocorrelation time.
@@ -62,14 +72,18 @@ class Ptemcee(MCMCSampler):
         A multiplicative factor for the estimated autocorrelation. Useful for
         cases where non-convergence can be observed by eye but the automated
         tools are failing.
-    autocorr_tau:
+    autocorr_tau: int, (1)
         The number of autocorrelation times to use in assessing if the
         autocorrelation time is stable.
-    frac_threshold: float, (0.01)
-        The maximum fractional change in the autocorrelation for the last
-        autocorr_tau steps. If the fractional change exceeds this value,
-        sampling will continue until the estimate of the autocorrelation time
-        can be trusted.
+    gradient_tau: float, (0.1)
+        The maximum (smoothed) local gradient of the ACT estimate to allow.
+        This ensures the ACT estimate is stable before finishing sampling.
+    gradient_mean_log_posterior: float, (0.1)
+        The maximum (smoothed) local gradient of the mean log-posterior to allow.
+        This ensures the mean log-posterior is stable before finishing sampling.
+    Q_tol: float (1.02)
+        The maximum between-chain to within-chain tolerance allowed (akin to
+        the Gelman-Rubin statistic).
     min_tau: int, (1)
         A minimum tau (autocorrelation time) to accept.
     check_point_deltaT: float, (600)
@@ -79,7 +93,7 @@ class Ptemcee(MCMCSampler):
     exit_code: int, (77)
         The code on which the sampler exits.
     store_walkers: bool (False)
-        If true, store the unthinned, unburnt chaines in the result. Note, this
+        If true, store the unthinned, unburnt chains in the result. Note, this
         is not recommended for cases where tau is large.
     ignore_keys_for_tau: str
         A pattern used to ignore keys in estimating the autocorrelation time.
@@ -90,6 +104,12 @@ class Ptemcee(MCMCSampler):
         The walkers are then initialized from the range of values obtained.
         If a list, for the keys in the list the optimization step is applied,
         otherwise the initial points are drawn from the prior.
+    niterations_per_check: int (5)
+        The number of iteration steps to take before checking the ACT. This
+        effectively pre-thins the chains. Larger values reduce the
+        per-evaluation overhead, but if made too large the pre-thinning may
+        be overly aggressive, effectively wasting compute time. If you see
+        tau=1, then niterations_per_check is likely too large.
 
 
     Other Parameters
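
The pre-thinning trade-off described for niterations_per_check can be made
concrete with some quick arithmetic. A minimal sketch with illustrative
numbers only (none of these values come from bilby):

    # Each convergence check advances every chain by niterations_per_check
    # steps but stores only one sample per walker, so the stored chain is
    # pre-thinned by that factor.
    ntemps, nwalkers, niterations_per_check = 10, 100, 5
    nchecks = 200

    evals_per_check = ntemps * nwalkers * niterations_per_check  # 5000
    total_likelihood_evals = evals_per_check * nchecks           # 1,000,000
    stored_steps_per_walker = nchecks                            # 200

    # If the true ACT (in raw steps) is below niterations_per_check, the
    # stored chain looks uncorrelated and the estimated tau collapses to 1.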
@@ -98,7 +118,7 @@ class Ptemcee(MCMCSampler):
         The number of walkers
     nsteps: int, (100)
         The number of steps to take
-    ntemps: int (2)
+    ntemps: int (10)
         The number of temperatures used by ptemcee
     Tmax: float
         The maximum temperature
@@ -107,15 +127,15 @@ class Ptemcee(MCMCSampler):
 
     # Arguments used by ptemcee
     default_kwargs = dict(
-        ntemps=20,
-        nwalkers=200,
+        ntemps=10,
+        nwalkers=100,
         Tmax=None,
         betas=None,
         a=2.0,
         adaptation_lag=10000,
         adaptation_time=100,
         random=None,
-        adapt=True,
+        adapt=False,
         swap_ratios=False,
     )
 
@@ -130,13 +150,17 @@ class Ptemcee(MCMCSampler):
         skip_import_verification=False,
         resume=True,
         nsamples=5000,
-        burn_in_nact=10,
+        burn_in_nact=50,
+        burn_in_fixed_discard=0,
+        mean_logl_frac=0.01,
         thin_by_nact=0.5,
         autocorr_tol=50,
         autocorr_c=5,
         safety=1,
-        autocorr_tau=5,
-        frac_threshold=0.01,
+        autocorr_tau=1,
+        gradient_tau=0.1,
+        gradient_mean_log_posterior=0.1,
+        Q_tol=1.02,
         min_tau=1,
         check_point_deltaT=600,
         threads=1,
@@ -145,7 +169,8 @@ class Ptemcee(MCMCSampler):
         store_walkers=False,
         ignore_keys_for_tau=None,
         pos0="prior",
-        niterations_per_check=10,
+        niterations_per_check=5,
+        log10beta_min=None,
         **kwargs
     ):
         super(Ptemcee, self).__init__(
@@ -184,14 +209,19 @@ class Ptemcee(MCMCSampler):
             autocorr_tau=autocorr_tau,
             safety=safety,
             burn_in_nact=burn_in_nact,
+            burn_in_fixed_discard=burn_in_fixed_discard,
+            mean_logl_frac=mean_logl_frac,
             thin_by_nact=thin_by_nact,
-            frac_threshold=frac_threshold,
+            gradient_tau=gradient_tau,
+            gradient_mean_log_posterior=gradient_mean_log_posterior,
+            Q_tol=Q_tol,
             nsamples=nsamples,
             ignore_keys_for_tau=ignore_keys_for_tau,
             min_tau=min_tau,
             niterations_per_check=niterations_per_check,
         )
         self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
+        logger.info("Using convergence inputs: {}".format(self.convergence_inputs))
 
         # Check if threads was given as an equivalent arg
         if threads == 1:
@@ -206,6 +236,23 @@ class Ptemcee(MCMCSampler):
         self.store_walkers = store_walkers
         self.pos0 = pos0
 
+        self._periodic = [
+            self.priors[key].boundary == "periodic" for key in self.search_parameter_keys
+        ]
+        self.priors.sample()
+        self._minima = np.array([
+            self.priors[key].minimum for key in self.search_parameter_keys
+        ])
+        self._range = np.array([
+            self.priors[key].maximum for key in self.search_parameter_keys
+        ]) - self._minima
+
+        self.log10beta_min = log10beta_min
+        if self.log10beta_min is not None:
+            betas = np.logspace(0, self.log10beta_min, self.ntemps)
+            logger.warning("Using betas {}".format(betas))
+            self.kwargs["betas"] = betas
+
     @property
     def sampler_function_kwargs(self):
         """ Kwargs passed to samper.sampler() """
@@ -322,7 +369,7 @@ class Ptemcee(MCMCSampler):
         return pos0
 
     def setup_sampler(self):
-        """ Either initialize the sampelr or read in the resume file """
+        """ Either initialize the sampler or read in the resume file """
         import ptemcee
 
         if os.path.isfile(self.resume_file) and self.resume is True:
@@ -335,11 +382,13 @@ class Ptemcee(MCMCSampler):
             self.iteration = data["iteration"]
             self.chain_array = data["chain_array"]
             self.log_likelihood_array = data["log_likelihood_array"]
+            self.log_posterior_array = data["log_posterior_array"]
             self.pos0 = data["pos0"]
             self.beta_list = data["beta_list"]
             self.sampler._betas = np.array(self.beta_list[-1])
             self.tau_list = data["tau_list"]
             self.tau_list_n = data["tau_list_n"]
+            self.Q_list = data["Q_list"]
             self.time_per_check = data["time_per_check"]
 
             # Initialize the pool
@@ -376,10 +425,12 @@ class Ptemcee(MCMCSampler):
             # Initialize storing results
             self.iteration = 0
             self.chain_array = self.get_zero_chain_array()
-            self.log_likelihood_array = self.get_zero_log_likelihood_array()
+            self.log_likelihood_array = self.get_zero_array()
+            self.log_posterior_array = self.get_zero_array()
             self.beta_list = []
             self.tau_list = []
             self.tau_list_n = []
+            self.Q_list = []
             self.time_per_check = []
             self.pos0 = self.get_pos0()
 
@@ -388,7 +439,7 @@ class Ptemcee(MCMCSampler):
     def get_zero_chain_array(self):
         return np.zeros((self.nwalkers, self.max_steps, self.ndim))
 
-    def get_zero_log_likelihood_array(self):
+    def get_zero_array(self):
         return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
 
     def get_pos0(self):
@@ -425,18 +476,27 @@ class Ptemcee(MCMCSampler):
                     self.pos0, storechain=False,
                     iterations=self.convergence_inputs.niterations_per_check,
                     **self.sampler_function_kwargs):
-                pass
+                pos0[:, :, self._periodic] = np.mod(
+                    pos0[:, :, self._periodic] - self._minima[self._periodic],
+                    self._range[self._periodic]
+                ) + self._minima[self._periodic]
 
             if self.iteration == self.chain_array.shape[1]:
                 self.chain_array = np.concatenate((
                     self.chain_array, self.get_zero_chain_array()), axis=1)
                 self.log_likelihood_array = np.concatenate((
-                    self.log_likelihood_array, self.get_zero_log_likelihood_array()),
+                    self.log_likelihood_array, self.get_zero_array()),
+                    axis=2)
+                self.log_posterior_array = np.concatenate((
+                    self.log_posterior_array, self.get_zero_array()),
                     axis=2)
 
             self.pos0 = pos0
             self.chain_array[:, self.iteration, :] = pos0[0, :, :]
             self.log_likelihood_array[:, :, self.iteration] = log_likelihood
+            self.log_posterior_array[:, :, self.iteration] = log_posterior
+            self.mean_log_posterior = np.mean(
+                self.log_posterior_array[:, :, :self.iteration], axis=1)
 
             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
@@ -444,6 +504,28 @@ class Ptemcee(MCMCSampler):
 
             self.iteration += 1
 
+            # Calculate minimum iteration step to discard
+            minimum_iteration = get_minimum_stable_iteration(
+                self.mean_log_posterior,
+                frac=self.convergence_inputs.mean_logl_frac
+            )
+            logger.debug("Minimum iteration = {}".format(minimum_iteration))
+
+            # Calculate the maximum discard number
+            discard_max = np.max(
+                [self.convergence_inputs.burn_in_fixed_discard,
+                 minimum_iteration]
+            )
+
+            if self.iteration > discard_max + self.nwalkers:
+                # If we have taken more than nwalkers steps after the discard
+                # then set the discard
+                self.discard = discard_max
+            else:
+                # If not, discard everything so far (avoids initialisation bias)
+                logger.debug("Too few steps to calculate convergence")
+                self.discard = self.iteration
+
             (
                 stop,
                 self.nburn,
@@ -451,7 +533,8 @@ class Ptemcee(MCMCSampler):
                 self.tau_int,
                 self.nsamples_effective,
             ) = check_iteration(
-                self.chain_array[:, :self.iteration + 1, :],
+                self.iteration,
+                self.chain_array[:, self.discard:self.iteration, :],
                 sampler,
                 self.convergence_inputs,
                 self.search_parameter_keys,
@@ -459,6 +542,8 @@ class Ptemcee(MCMCSampler):
                 self.beta_list,
                 self.tau_list,
                 self.tau_list_n,
+                self.Q_list,
+                self.mean_log_posterior,
             )
 
             if stop:
@@ -479,19 +564,21 @@ class Ptemcee(MCMCSampler):
 
         # Get 0-likelihood samples and store in the result
         self.result.samples = self.chain_array[
-            :, self.nburn : self.iteration : self.thin, :
+            :, self.discard + self.nburn : self.iteration : self.thin, :
         ].reshape((-1, self.ndim))
         loglikelihood = self.log_likelihood_array[
-            0, :, self.nburn : self.iteration : self.thin
+            0, :, self.discard + self.nburn : self.iteration : self.thin
         ]  # nwalkers, nsteps
         self.result.log_likelihood_evaluations = loglikelihood.reshape((-1))
 
         if self.store_walkers:
             self.result.walkers = self.sampler.chain
         self.result.nburn = self.nburn
+        self.result.discard = self.discard
 
         log_evidence, log_evidence_err = compute_evidence(
-            sampler, self.log_likelihood_array, self.outdir, self.label, self.nburn,
+            sampler, self.log_likelihood_array, self.outdir,
+            self.label, self.discard, self.nburn,
             self.thin, self.iteration,
         )
         self.result.log_evidence = log_evidence
@@ -524,43 +611,79 @@ class Ptemcee(MCMCSampler):
             self.label,
             self.nsamples_effective,
             self.sampler,
+            self.discard,
             self.nburn,
             self.thin,
             self.search_parameter_keys,
             self.resume_file,
             self.log_likelihood_array,
+            self.log_posterior_array,
             self.chain_array,
             self.pos0,
             self.beta_list,
             self.tau_list,
             self.tau_list_n,
+            self.Q_list,
             self.time_per_check,
         )
 
-        if plot and not np.isnan(self.nburn):
-            # Generate the walkers plot diagnostic
-            plot_walkers(
-                self.chain_array[:, : self.iteration, :],
-                self.nburn,
-                self.thin,
-                self.search_parameter_keys,
-                self.outdir,
-                self.label,
-            )
+        if plot:
+            try:
+                # Generate the walkers plot diagnostic
+                plot_walkers(
+                    self.chain_array[:, : self.iteration, :],
+                    self.nburn,
+                    self.thin,
+                    self.search_parameter_keys,
+                    self.outdir,
+                    self.label,
+                    self.discard,
+                )
+            except Exception as e:
+                logger.info("Walkers plot failed with exception {}".format(e))
 
-            # Generate the tau plot diagnostic
-            plot_tau(
-                self.tau_list_n,
-                self.tau_list,
-                self.search_parameter_keys,
-                self.outdir,
-                self.label,
-                self.tau_int,
-                self.convergence_inputs.autocorr_tau,
-            )
+            try:
+                # Generate the tau plot diagnostic if DEBUG
+                if logger.level < logging.INFO:
+                    plot_tau(
+                        self.tau_list_n,
+                        self.tau_list,
+                        self.search_parameter_keys,
+                        self.outdir,
+                        self.label,
+                        self.tau_int,
+                        self.convergence_inputs.autocorr_tau,
+                    )
+            except Exception as e:
+                logger.info("tau plot failed with exception {}".format(e))
+
+            try:
+                plot_mean_log_posterior(
+                    self.mean_log_posterior,
+                    self.outdir,
+                    self.label,
+                )
+            except Exception as e:
+                logger.info("mean_logl plot failed with exception {}".format(e))
+
+
+def get_minimum_stable_iteration(mean_array, frac, nsteps_min=10):
+    nsteps = mean_array.shape[1]
+    if nsteps < nsteps_min:
+        return 0
+
+    min_it = 0
+    for x in mean_array:
+        maxl = np.max(x)
+        fracdiff = (maxl - x) / np.abs(maxl)
+        idxs = fracdiff < frac
+        if np.sum(idxs) > 0:
+            min_it = np.max([min_it, np.min(np.arange(len(idxs))[idxs])])
+    return min_it
 
 
 def check_iteration(
+    iteration,
     samples,
     sampler,
     convergence_inputs,
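
get_minimum_stable_iteration locates the first step at which each
temperature's mean log-posterior sits within a fraction mean_logl_frac of its
maximum, and the discard is taken at the latest such step over temperatures.
A minimal sketch of the per-temperature logic on a synthetic trace (the real
function also requires at least nsteps_min steps):

    import numpy as np

    # One temperature's mean log-posterior trace (synthetic)
    x = np.array([-100.0, -50.0, -12.0, -10.5, -10.0, -10.2])
    frac = 0.1  # accept values within 10% of the maximum

    maxl = np.max(x)                      # -10.0
    fracdiff = (maxl - x) / np.abs(maxl)  # [9.0, 4.0, 0.2, 0.05, 0.0, 0.02]
    print(np.argmax(fracdiff < frac))     # 3, the minimum stable iteration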
@@ -569,6 +692,8 @@ def check_iteration(
     beta_list,
     tau_list,
     tau_list_n,
+    Q_list,
+    mean_log_posterior,
 ):
     """ Per-iteration logic to calculate the convergence check
 
@@ -594,24 +719,16 @@ def check_iteration(
     nsamples_effective: int
         The effective number of samples after burning and thinning
     """
-    import emcee
 
     ci = convergence_inputs
-    nwalkers, iteration, ndim = samples.shape
-
-    # Compute ACT tau for 0-temperature chains
-    tau_array = np.zeros((nwalkers, ndim))
-    for ii in range(nwalkers):
-        for jj, key in enumerate(search_parameter_keys):
-            if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
-                continue
-            try:
-                tau_array[ii, jj] = emcee.autocorr.integrated_time(
-                    samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
-            except emcee.autocorr.AutocorrError:
-                tau_array[ii, jj] = np.inf
+    # Note: nsteps is the number of steps in the samples, while iteration is
+    # the current iteration number; iteration exceeds nsteps by the number
+    # of discarded steps
+    nwalkers, nsteps, ndim = samples.shape
 
-    # Maximum over paramters, mean over walkers
+    tau_array = calculate_tau_array(samples, search_parameter_keys, ci)
+
+    # Maximum over parameters, mean over walkers
     tau = np.max(np.mean(tau_array, axis=0))
 
     # Apply multiplicative safety factor
@@ -622,37 +739,80 @@ def check_iteration(
     tau_list.append(list(np.mean(tau_array, axis=0)))
     tau_list_n.append(iteration)
 
-    # Convert to an integer
-    tau_int = int(np.ceil(tau)) if not np.isnan(tau) else tau
+    Q = get_Q_convergence(samples)
+    Q_list.append(Q)
 
-    if np.isnan(tau_int) or np.isinf(tau_int):
+    if np.isnan(tau) or np.isinf(tau):
         print_progress(
             iteration, sampler, time_per_check, np.nan, np.nan,
-            np.nan, np.nan, False, convergence_inputs,
+            np.nan, np.nan, np.nan, False, convergence_inputs, Q,
         )
         return False, np.nan, np.nan, np.nan, np.nan
 
+    # Convert to an integer
+    tau_int = int(np.ceil(tau))
+
     # Calculate the effective number of samples available
     nburn = int(ci.burn_in_nact * tau_int)
     thin = int(np.max([1, ci.thin_by_nact * tau_int]))
     samples_per_check = nwalkers / thin
-    nsamples_effective = int(nwalkers * (iteration - nburn) / thin)
+    nsamples_effective = int(nwalkers * (nsteps - nburn) / thin)
 
     # Calculate convergence boolean
-    converged = ci.nsamples < nsamples_effective
-
-    # Calculate fractional change in tau from previous iteration
-    check_taus = np.array(tau_list[-tau_int * ci.autocorr_tau :])
-    taus_per_parameter = check_taus[-1, :]
-    if not np.any(np.isnan(check_taus)):
-        frac = (taus_per_parameter - check_taus) / taus_per_parameter
-        max_frac = np.max(frac)
-        tau_usable = np.all(frac < ci.frac_threshold)
+    converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
+    logger.debug("Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}"
+                 .format(Q < ci.Q_tol, ci.nsamples < nsamples_effective))
+
+    GRAD_WINDOW_LENGTH = nwalkers + 1
+    nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int])
+    lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check])
+    check_taus = np.array(tau_list[lower_tau_index:])
+    if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
+        gradient_tau = get_max_gradient(
+            check_taus, axis=0, window_length=11)
+
+        if gradient_tau < ci.gradient_tau:
+            logger.debug(
+                "tau usable as {} < gradient_tau={}"
+                .format(gradient_tau, ci.gradient_tau)
+            )
+            tau_usable = True
+        else:
+            logger.debug(
+                "tau not usable as {} > gradient_tau={}"
+                .format(gradient_tau, ci.gradient_tau)
+            )
+            tau_usable = False
+
+        check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:]
+        gradient_mean_log_posterior = get_max_gradient(
+            check_mean_log_posterior, axis=1, window_length=GRAD_WINDOW_LENGTH,
+            smooth=True)
+
+        if gradient_mean_log_posterior < ci.gradient_mean_log_posterior:
+            logger.debug(
+                "tau usable as {} < gradient_mean_log_posterior={}"
+                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+            )
+            tau_usable *= True
+        else:
+            logger.debug(
+                "tau not usable as {} > gradient_mean_log_posterior={}"
+                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+            )
+            tau_usable = False
+
     else:
-        max_frac = np.nan
+        logger.debug("ACT is nan")
+        gradient_tau = np.nan
+        gradient_mean_log_posterior = np.nan
         tau_usable = False
 
-    if iteration < tau_int * ci.autocorr_tol or tau_int < ci.min_tau:
+    if nsteps < tau_int * ci.autocorr_tol:
+        logger.debug("ACT less than autocorr_tol")
+        tau_usable = False
+    elif tau_int < ci.min_tau:
+        logger.debug("ACT less than min_tau")
         tau_usable = False
 
     # Print an update on the progress
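
The Q gate used in the convergence decision above comes from
get_Q_convergence (defined in the next hunk). A self-contained sketch with
synthetic chains; q_convergence is a local re-implementation mirroring that
function, showing Q near 1 for well-mixed walkers but well above Q_tol when
the walkers are offset from one another:

    import numpy as np

    def q_convergence(samples):
        # Within-chain variance W, between-chain variance B, and the
        # Gelman-Rubin-like ratio sqrt(Vhat / W), maximised over dimensions
        nwalkers, nsteps, ndim = samples.shape
        W = np.mean(np.var(samples, axis=1), axis=0)
        per_walker_mean = np.mean(samples, axis=1)
        mean = np.mean(per_walker_mean, axis=0)
        B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
        Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
        return np.max(np.sqrt(Vhat / W))

    rng = np.random.default_rng(0)
    mixed = rng.normal(0, 1, (8, 500, 1))
    offset = mixed + np.linspace(0, 5, 8)[:, None, None]  # per-walker shifts
    print(q_convergence(mixed))   # ~1.0
    print(q_convergence(offset))  # ~2.1, far from converged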
@@ -663,14 +823,39 @@ def check_iteration(
         nsamples_effective,
         samples_per_check,
         tau_int,
-        max_frac,
+        gradient_tau,
+        gradient_mean_log_posterior,
         tau_usable,
         convergence_inputs,
+        Q
     )
     stop = converged and tau_usable
     return stop, nburn, thin, tau_int, nsamples_effective
 
 
+def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False):
+    if smooth:
+        x = scipy.signal.savgol_filter(
+            x, axis=axis, window_length=window_length, polyorder=3)
+    return np.max(scipy.signal.savgol_filter(
+        x, axis=axis, window_length=window_length, polyorder=polyorder,
+        deriv=1))
+
+
+def get_Q_convergence(samples):
+    nwalkers, nsteps, ndim = samples.shape
+    if nsteps > 1:
+        W = np.mean(np.var(samples, axis=1), axis=0)
+        per_walker_mean = np.mean(samples, axis=1)
+        mean = np.mean(per_walker_mean, axis=0)
+        B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
+        Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
+        Q_per_dim = np.sqrt(Vhat / W)
+        return np.max(Q_per_dim)
+    else:
+        return np.inf
+
+
 def print_progress(
     iteration,
     sampler,
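
get_max_gradient relies on scipy.signal.savgol_filter with deriv=1, which
returns the local slope of a low-order polynomial fit in each window. A
sketch on a synthetic tau series that rises and then plateaus, showing the
smoothed gradient falling towards zero as the estimate stabilises:

    import numpy as np
    import scipy.signal

    steps = np.arange(200)
    tau_estimate = 10 * (1 - np.exp(-steps / 50.))  # rises, then flattens
    tau_estimate += np.random.default_rng(1).normal(0, 0.05, steps.size)

    slope = scipy.signal.savgol_filter(
        tau_estimate, window_length=11, polyorder=2, deriv=1)
    print(np.max(slope[:20]), np.max(slope[-20:]))  # early slope >> late slope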
@@ -678,17 +863,19 @@ def print_progress(
     nsamples_effective,
     samples_per_check,
     tau_int,
-    max_frac,
+    gradient_tau,
+    gradient_mean_log_posterior,
     tau_usable,
     convergence_inputs,
+    Q,
 ):
     # Setup acceptance string
     acceptance = sampler.acceptance_fraction[0, :]
-    acceptance_str = "{:1.2f}->{:1.2f}".format(np.min(acceptance), np.max(acceptance))
+    acceptance_str = "{:1.2f}-{:1.2f}".format(np.min(acceptance), np.max(acceptance))
 
     # Setup tswap acceptance string
     tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
-    tswap_acceptance_str = "{:1.2f}->{:1.2f}".format(
+    tswap_acceptance_str = "{:1.2f}-{:1.2f}".format(
         np.min(tswap_acceptance_fraction), np.max(tswap_acceptance_fraction)
     )
 
@@ -701,37 +888,59 @@ def print_progress(
 
     sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
 
-    if max_frac >= 0:
-        tau_str = "{}(+{:0.2f})".format(tau_int, max_frac)
-    else:
-        tau_str = "{}({:0.2f})".format(tau_int, max_frac)
+    tau_str = "{}(+{:0.2f},+{:0.2f})".format(
+        tau_int, gradient_tau, gradient_mean_log_posterior
+    )
+
     if tau_usable:
         tau_str = "={}".format(tau_str)
     else:
         tau_str = "!{}".format(tau_str)
 
+    Q_str = "{:0.2f}".format(Q)
+
     evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
 
     ncalls = "{:1.1e}".format(
         convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps)
     eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
-    samp_timing = "{:1.1f}ms/sm".format(1e3 * ave_time_per_check / samples_per_check)
-
-    print(
-        "{}| {}| nc:{}| a0:{}| swp:{}| n:{}<{}| tau{}| {}| {}".format(
-            iteration,
-            str(sampling_time).split(".")[0],
-            ncalls,
-            acceptance_str,
-            tswap_acceptance_str,
-            nsamples_effective,
-            convergence_inputs.nsamples,
-            tau_str,
-            eval_timing,
-            samp_timing,
-        ),
-        flush=True,
-    )
+
+    try:
+        print(
+            "{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}".format(
+                iteration,
+                str(sampling_time).split(".")[0],
+                ncalls,
+                acceptance_str,
+                tswap_acceptance_str,
+                nsamples_effective,
+                convergence_inputs.nsamples,
+                tau_str,
+                Q_str,
+                eval_timing,
+            ),
+            flush=True,
+        )
+    except OSError as e:
+        logger.debug("Failed to print iteration due to :{}".format(e))
+
+
+def calculate_tau_array(samples, search_parameter_keys, ci):
+    """ Compute ACT tau for 0-temperature chains """
+    import emcee
+    nwalkers, nsteps, ndim = samples.shape
+    tau_array = np.zeros((nwalkers, ndim)) + np.inf
+    if nsteps > 1:
+        for ii in range(nwalkers):
+            for jj, key in enumerate(search_parameter_keys):
+                if ci.ignore_keys_for_tau and ci.ignore_keys_for_tau in key:
+                    continue
+                try:
+                    tau_array[ii, jj] = emcee.autocorr.integrated_time(
+                        samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
+                except emcee.autocorr.AutocorrError:
+                    tau_array[ii, jj] = np.inf
+    return tau_array
 
 
 def checkpoint(
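
calculate_tau_array defers to emcee.autocorr.integrated_time for each walker
and parameter. A sketch on an AR(1) series, for which the integrated
autocorrelation time has the closed form (1 + rho) / (1 - rho):

    import numpy as np
    import emcee

    rho, n = 0.9, 50000
    rng = np.random.default_rng(2)
    x = np.zeros(n)
    for i in range(1, n):
        x[i] = rho * x[i - 1] + rng.normal()

    tau = emcee.autocorr.integrated_time(x, c=5, tol=0)[0]
    print(tau, (1 + rho) / (1 - rho))  # both ~19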
@@ -740,16 +949,19 @@ def checkpoint(
     label,
     nsamples_effective,
     sampler,
+    discard,
     nburn,
     thin,
     search_parameter_keys,
     resume_file,
     log_likelihood_array,
+    log_posterior_array,
     chain_array,
     pos0,
     beta_list,
     tau_list,
     tau_list_n,
+    Q_list,
     time_per_check,
 ):
     logger.info("Writing checkpoint and diagnostics")
@@ -758,7 +970,7 @@ def checkpoint(
     # Store the samples if possible
     if nsamples_effective > 0:
         filename = "{}/{}_samples.txt".format(outdir, label)
-        samples = np.array(chain_array)[:, nburn : iteration : thin, :].reshape(
+        samples = np.array(chain_array)[:, discard + nburn : iteration : thin, :].reshape(
             (-1, ndim)
         )
         df = pd.DataFrame(samples, columns=search_parameter_keys)
@@ -774,8 +986,10 @@ def checkpoint(
         beta_list=beta_list,
         tau_list=tau_list,
         tau_list_n=tau_list_n,
+        Q_list=Q_list,
         time_per_check=time_per_check,
         log_likelihood_array=log_likelihood_array,
+        log_posterior_array=log_posterior_array,
         chain_array=chain_array,
         pos0=pos0,
     )
@@ -786,17 +1000,33 @@ def checkpoint(
     logger.info("Finished writing checkpoint")
 
 
-def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
+def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
+                 discard=0):
     """ Method to plot the trace of the walkers in an ensemble MCMC plot """
     nwalkers, nsteps, ndim = walkers.shape
+    if np.isnan(nburn):
+        nburn = nsteps
+    if np.isnan(thin):
+        thin = 1
     idxs = np.arange(nsteps)
     fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
-    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.05,)
+    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.1,)
+
+    # Plot the fixed burn-in
+    if discard > 0:
+        for i, (ax, axh) in enumerate(axes):
+            ax.plot(
+                idxs[: discard],
+                walkers[:, : discard, i].T,
+                color="gray",
+                **scatter_kwargs
+            )
+
     # Plot the burn-in
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[: nburn + 1],
-            walkers[:, : nburn + 1, i].T,
+            idxs[discard: discard + nburn + 1],
+            walkers[:, discard: discard + nburn + 1, i].T,
             color="C1",
             **scatter_kwargs
         )
@@ -804,12 +1034,14 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
     # Plot the thinned posterior samples
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[nburn::thin],
-            walkers[:, nburn::thin, i].T,
+            idxs[discard + nburn::thin],
+            walkers[:, discard + nburn::thin, i].T,
             color="C0",
             **scatter_kwargs
         )
-        axh.hist(walkers[:, nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
+        axh.hist(walkers[:, discard + nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
+
+    for i, (ax, axh) in enumerate(axes):
         axh.set_xlabel(parameter_labels[i])
         ax.set_ylabel(parameter_labels[i])
 
@@ -820,25 +1052,43 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label):
 
 
 def plot_tau(
-    tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau
+    tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau,
 ):
     fig, ax = plt.subplots()
     for i, key in enumerate(search_parameter_keys):
         ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
-    ax.axvline(tau_list_n[-1] - tau * autocorr_tau)
     ax.set_xlabel("Iteration")
     ax.set_ylabel(r"$\langle \tau \rangle$")
     ax.legend()
+    fig.tight_layout()
     fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
     plt.close(fig)
 
 
-def compute_evidence(sampler, log_likelihood_array, outdir, label, nburn, thin,
+def plot_mean_log_posterior(mean_log_posterior, outdir, label):
+
+    ntemps, nsteps = mean_log_posterior.shape
+    ymax = np.max(mean_log_posterior)
+    ymin = np.min(mean_log_posterior[:, -100:])
+    ymax += 0.1 * (ymax - ymin)
+    ymin -= 0.1 * (ymax - ymin)
+
+    fig, ax = plt.subplots()
+    idxs = np.arange(nsteps)
+    ax.plot(idxs, mean_log_posterior.T)
+    ax.set(xlabel="Iteration", ylabel=r"$\langle\mathrm{log-posterior}\rangle$",
+           ylim=(ymin, ymax))
+    fig.tight_layout()
+    fig.savefig("{}/{}_checkpoint_meanlogposterior.png".format(outdir, label))
+    plt.close(fig)
+
+
+def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nburn, thin,
                      iteration, make_plots=True):
     """ Computes the evidence using thermodynamic integration """
     betas = sampler.betas
     # We compute the evidence without the burnin samples, but we do not thin
-    lnlike = log_likelihood_array[:, :, nburn : iteration]
+    lnlike = log_likelihood_array[:, :, discard + nburn : iteration]
     mean_lnlikes = np.mean(np.mean(lnlike, axis=1), axis=1)
 
     mean_lnlikes = mean_lnlikes[::-1]
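
compute_evidence implements thermodynamic integration: since
d ln Z(beta) / d beta = <ln L>_beta, the log evidence is the integral of the
mean log-likelihood over beta from 0 to 1. A minimal sketch with stand-in
values (the actual function additionally estimates an error and makes
diagnostic plots):

    import numpy as np

    betas = np.logspace(0, -4, 10)             # beta = 1 down to ~1e-4
    mean_lnlikes = -100. / (1. + 10 * betas)   # stand-in for <ln L>_beta

    # Integrate over ascending beta, mirroring the [::-1] reversal above
    log_z = np.trapz(mean_lnlikes[::-1], betas[::-1])
    print(log_z)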
@@ -911,6 +1161,29 @@ class LikePriorEvaluator(object):
     def __init__(self, search_parameter_keys, use_ratio=False):
         self.search_parameter_keys = search_parameter_keys
         self.use_ratio = use_ratio
+        self.periodic_set = False
+
+    def _setup_periodic(self):
+        self._periodic = [
+            priors[key].boundary == "periodic" for key in self.search_parameter_keys
+        ]
+        priors.sample()
+        self._minima = np.array([
+            priors[key].minimum for key in self.search_parameter_keys
+        ])
+        self._range = np.array([
+            priors[key].maximum for key in self.search_parameter_keys
+        ]) - self._minima
+        self.periodic_set = True
+
+    def _wrap_periodic(self, array):
+        if not self.periodic_set:
+            self._setup_periodic()
+        array[self._periodic] = np.mod(
+            array[self._periodic] - self._minima[self._periodic],
+            self._range[self._periodic]
+        ) + self._minima[self._periodic]
+        return array
 
     def logl(self, v_array):
         parameters = {key: v for key, v in zip(self.search_parameter_keys, v_array)}
diff --git a/test/core/sampler/ptemcee_test.py b/test/core/sampler/ptemcee_test.py
index 70abcb0af..08439c814 100644
--- a/test/core/sampler/ptemcee_test.py
+++ b/test/core/sampler/ptemcee_test.py
@@ -28,36 +28,36 @@ class TestPTEmcee(unittest.TestCase):
 
     def test_default_kwargs(self):
         expected = dict(
-            ntemps=20,
-            nwalkers=200,
+            ntemps=10,
+            nwalkers=100,
             Tmax=None,
             betas=None,
             a=2.0,
             adaptation_lag=10000,
             adaptation_time=100,
             random=None,
-            adapt=True,
+            adapt=False,
             swap_ratios=False,
         )
         self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs(self):
         expected = dict(
-            ntemps=20,
-            nwalkers=200,
+            ntemps=10,
+            nwalkers=100,
             Tmax=None,
             betas=None,
             a=2.0,
             adaptation_lag=10000,
             adaptation_time=100,
             random=None,
-            adapt=True,
+            adapt=False,
             swap_ratios=False,
         )
         for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
             del new_kwargs["nwalkers"]
-            new_kwargs[equiv] = 200
+            new_kwargs[equiv] = 100
             self.sampler.kwargs = new_kwargs
             self.assertDictEqual(expected, self.sampler.kwargs)
 
-- 
GitLab