diff --git a/AUTHORS.md b/AUTHORS.md
index 8acda7e024ce8d94b7c7e197367ff989b84f98ec..7bf380bd69fde382046bfcc339fef7294cb75e12 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -59,6 +59,7 @@ Moritz Huebner
 Nicola De Lillo
 Nikhil Sarin
 Nirban Bose
+Noah Wolfe
 Olivia Wilk
 Paul Easter
 Paul Lasky
diff --git a/README.rst b/README.rst
index 37b98613e0fb68572f378cfb9340b53856b3ac93..70bc2efa539a21c4614e3245a4ad7540f4d4bd6d 100644
--- a/README.rst
+++ b/README.rst
@@ -53,6 +53,7 @@ as requested in their associated documentation.
 * `pymultinest <https://github.com/JohannesBuchner/PyMultiNest>`__
 * `cpnest <https://github.com/johnveitch/cpnest>`__
 * `emcee <https://github.com/dfm/emcee>`__
+* `nessai <https://github.com/mj-will/nessai>`__
 * `ptemcee <https://github.com/willvousden/ptemcee>`__
 * `ptmcmcsampler <https://github.com/jellis18/PTMCMCSampler>`__
 * `pypolychord <https://github.com/PolyChord/PolyChordLite>`__
diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py
index d0d05037031383ff9a22a08898856e06a6ddbf8d..a0d3e72ff5dbcc7224e6006fcda38ad20cf2f370 100644
--- a/bilby/core/sampler/nessai.py
+++ b/bilby/core/sampler/nessai.py
@@ -1,10 +1,12 @@
 import os
+import sys
 
 import numpy as np
 from pandas import DataFrame
+from scipy.special import logsumexp
 
 from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, signal_wrapper
 
 
 class Nessai(NestedSampler):
@@ -19,8 +21,22 @@ class Nessai(NestedSampler):
     """
 
     _default_kwargs = None
+    _run_kwargs_list = None
     sampling_seed_key = "seed"
 
+    @property
+    def run_kwargs_list(self):
+        """List of kwargs used in the run method of :code:`FlowSampler`"""
+        if not self._run_kwargs_list:
+            from nessai.utils.bilbyutils import get_run_kwargs_list
+
+            self._run_kwargs_list = get_run_kwargs_list()
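+            # Remove kwargs that bilby does not expose to the user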
+            ignored_kwargs = ["save"]
+            for ik in ignored_kwargs:
+                if ik in self._run_kwargs_list:
+                    self._run_kwargs_list.remove(ik)
+        return self._run_kwargs_list
+
     @property
     def default_kwargs(self):
         """Default kwargs for nessai.
@@ -28,32 +44,38 @@ class Nessai(NestedSampler):
         Retrieves default values from nessai directly and then includes any
         bilby specific defaults. This avoids the need to update bilby when the
         defaults change or new kwargs are added to nessai.
+
+        Includes the following kwargs that are specific to bilby:
+
+        - :code:`nessai_log_level`: allows setting the logging level in nessai
+        - :code:`nessai_logging_stream`: allows setting the logging stream
+        - :code:`nessai_plot`: allows toggling the plotting in FlowSampler.run
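+
+        For example, to lower nessai's logging verbosity and disable its
+        plots, one can pass :code:`nessai_log_level="WARNING"` and
+        :code:`nessai_plot=False` to :code:`run_sampler`.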
         """
         if not self._default_kwargs:
-            from inspect import signature
-
-            from nessai.flowsampler import FlowSampler
-            from nessai.nestedsampler import NestedSampler
-            from nessai.proposal import AugmentedFlowProposal, FlowProposal
-
-            kwargs = {}
-            classes = [
-                AugmentedFlowProposal,
-                FlowProposal,
-                NestedSampler,
-                FlowSampler,
-            ]
-            for c in classes:
-                kwargs.update(
-                    {
-                        k: v.default
-                        for k, v in signature(c).parameters.items()
-                        if v.default is not v.empty
-                    }
-                )
+            from nessai.utils.bilbyutils import get_all_kwargs
+
+            kwargs = get_all_kwargs()
+
             # Defaults for bilby that will override nessai defaults
-            bilby_defaults = dict(output=None, exit_code=self.exit_code)
+            bilby_defaults = dict(
+                output=None,
+                exit_code=self.exit_code,
+                nessai_log_level=None,
+                nessai_logging_stream="stdout",
+                nessai_plot=True,
+                plot_posterior=False,  # bilby already produces a posterior plot
+                log_on_iteration=False,  # Use periodic logging by default
+                logging_interval=60,  # Log every 60 seconds
+            )
             kwargs.update(bilby_defaults)
+            # Kwargs that cannot be set in bilby
+            remove = [
+                "save",
+                "signal_handling",
+            ]
+            for k in remove:
+                if k in kwargs:
+                    kwargs.pop(k)
             self._default_kwargs = kwargs
         return self._default_kwargs
 
@@ -72,12 +94,10 @@ class Nessai(NestedSampler):
         """
         return self.priors.ln_prob(theta, axis=0)
 
-    def run_sampler(self):
-        from nessai.flowsampler import FlowSampler
-        from nessai.livepoint import dict_to_live_points, live_points_to_array
+    def get_nessai_model(self):
+        """Get the model for nessai."""
+        from nessai.livepoint import dict_to_live_points
         from nessai.model import Model as BaseModel
-        from nessai.posterior import compute_weights
-        from nessai.utils import setup_logger
 
         class Model(BaseModel):
             """A wrapper class to pass our log_likelihood and priors into nessai
@@ -124,47 +144,115 @@ class Nessai(NestedSampler):
                 """Proposal probability for new the point"""
                 return self.log_prior(x)
 
-        # Setup the logger for nessai using the same settings as the bilby logger
-        setup_logger(
-            self.outdir, label=self.label, log_level=logger.getEffectiveLevel()
-        )
+            @staticmethod
+            def from_unit_hypercube(x):
+                """Map samples from the unit hypercube to the prior."""
+                theta = {}
+                for n in self._search_parameter_keys:
+                    theta[n] = self.priors[n].rescale(x[n])
+                return dict_to_live_points(theta)
+
+            @staticmethod
+            def to_unit_hypercube(x):
+                """Map samples from the prior to the unit hypercube."""
+                theta = {n: x[n] for n in self._search_parameter_keys}
+                return dict_to_live_points(self.priors.cdf(theta))
+
         model = Model(self.search_parameter_keys, self.priors)
-        try:
-            out = FlowSampler(model, **self.kwargs)
-            out.run(save=True, plot=self.plot)
-        except TypeError as e:
-            raise TypeError(f"Unable to initialise nessai sampler with error: {e}")
-        except (SystemExit, KeyboardInterrupt) as e:
-            import sys
-
-            logger.info(
-                f"Caught {type(e).__name__} with args {e.args}, "
-                f"exiting with signal {self.exit_code}"
-            )
-            sys.exit(self.exit_code)
+        return model
+
+    def split_kwargs(self):
+        """Split kwargs into configuration and run time kwargs"""
+        kwargs = self.kwargs.copy()
+        run_kwargs = {}
+        for k in self.run_kwargs_list:
+            run_kwargs[k] = kwargs.pop(k)
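+        # "nessai_plot" maps onto the "plot" argument of FlowSampler.run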
+        run_kwargs["plot"] = kwargs.pop("nessai_plot")
+        return kwargs, run_kwargs
+
+    def get_posterior_weights(self):
+        """Get the posterior weights for the nested samples"""
+        from nessai.posterior import compute_weights
+
+        _, log_weights = compute_weights(
+            np.array(self.fs.nested_samples["logL"]),
+            np.array(self.fs.ns.state.nlive),
+        )
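+        # Normalise the weights so that they sum to one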
+        w = np.exp(log_weights - logsumexp(log_weights))
+        return w
+
+    def get_nested_samples(self):
+        """Get the nested samples dataframe"""
+        ns = DataFrame(self.fs.nested_samples)
+        ns.rename(
+            columns=dict(logL="log_likelihood", logP="log_prior", it="iteration"),
+            inplace=True,
+        )
+        return ns
+
+    def update_result(self):
+        """Update the result object."""
+        from nessai.livepoint import live_points_to_array
 
         # Manually set likelihood evaluations because parallelisation breaks the counter
-        self.result.num_likelihood_evaluations = out.ns.likelihood_evaluations[-1]
+        self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations
 
         self.result.samples = live_points_to_array(
-            out.posterior_samples, self.search_parameter_keys
+            self.fs.posterior_samples, self.search_parameter_keys
         )
-        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
-        self.result.nested_samples = DataFrame(out.nested_samples)
-        self.result.nested_samples.rename(
-            columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True
+        self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"]
+        self.result.nested_samples = self.get_nested_samples()
+        self.result.nested_samples["weights"] = self.get_posterior_weights()
+        self.result.log_evidence = self.fs.log_evidence
+        self.result.log_evidence_err = self.fs.log_evidence_error
+
+    @signal_wrapper
+    def run_sampler(self):
+        """Run the sampler.
+
+        Nessai is designed to be run in two stages: initialise the sampler
+        and then call the run method with additional configuration. This means
+        there are effectively two sets of keyword arguments: one for
+        initialising the sampler and the other for the run function.
+        """
+        from nessai.flowsampler import FlowSampler
+        from nessai.utils import setup_logger
+
+        kwargs, run_kwargs = self.split_kwargs()
+
+        # Set up the logger for nessai: use nessai_log_level if specified, else use
+        # the level of the bilby logger.
+        nessai_log_level = kwargs.pop("nessai_log_level")
+        if nessai_log_level is None or nessai_log_level == "bilby":
+            nessai_log_level = logger.getEffectiveLevel()
+        nessai_logging_stream = kwargs.pop("nessai_logging_stream")
+
+        setup_logger(
+            self.outdir,
+            label=self.label,
+            log_level=nessai_log_level,
+            stream=nessai_logging_stream,
         )
-        _, log_weights = compute_weights(
-            np.array(self.result.nested_samples.log_likelihood),
-            np.array(out.ns.state.nlive),
+
+        # Get the nessai model
+        model = self.get_nessai_model()
+
+        # Configure the sampler
+        self.fs = FlowSampler(
+            model,
+            signal_handling=False,  # Disable signal handling so it can be handled by bilby
+            **kwargs,
         )
-        self.result.nested_samples["weights"] = np.exp(log_weights)
-        self.result.log_evidence = out.ns.log_evidence
-        self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive)
+        # Run the sampler
+        self.fs.run(**run_kwargs)
+
+        # Update the result
+        self.update_result()
 
         return self.result
 
     def _translate_kwargs(self, kwargs):
+        """Translate the keyword arguments"""
         super()._translate_kwargs(kwargs)
         if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
@@ -178,10 +266,7 @@ class Nessai(NestedSampler):
                 kwargs["n_pool"] = self._npool
 
     def _verify_kwargs_against_default_kwargs(self):
-        """
-        Set the directory where the output will be written
-        and check resume and checkpoint status.
-        """
+        """Verify the keyword arguments"""
         if "config_file" in self.kwargs:
             d = load_json(self.kwargs["config_file"], None)
             self.kwargs.update(d)
@@ -190,10 +275,6 @@ class Nessai(NestedSampler):
         if not self.kwargs["plot"]:
             self.kwargs["plot"] = self.plot
 
-        if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1:
-            logger.warning("Setting pool to None (n_pool=1 & max_threads=1)")
-            self.kwargs["n_pool"] = None
-
         if not self.kwargs["output"]:
             self.kwargs["output"] = os.path.join(
                 self.outdir, f"{self.label}_nessai", ""
@@ -202,5 +283,21 @@ class Nessai(NestedSampler):
         check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
+    def write_current_state(self):
+        """Write the current state of the sampler"""
+        self.fs.ns.checkpoint()
+
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Overrides the base class method to make sure that :code:`Nessai`
+        terminates properly.
+        """
+        if hasattr(self, "fs"):
+            self.fs.terminate_run(code=signum)
+        else:
+            logger.warning("Sampler is not initialized")
+        self._log_interruption(signum=signum)
+        sys.exit(self.exit_code)
+
     def _setup_pool(self):
         pass
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 2534b0369d1de8d1b75e8561f3cafb4cda26ec73..85b253b1e97fd33d9a55b88bf68771697af029d6 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -47,7 +47,7 @@ class Ptemcee(MCMCSampler):
     list commonly used kwargs and the bilby defaults.
 
     Parameters
-    ==========
+    ----------
     nsamples: int, (5000)
         The requested number of samples. Note, in cases where the
         autocorrelation parameter is difficult to measure, it is possible to
@@ -116,7 +116,7 @@ class Ptemcee(MCMCSampler):
 
 
     Other Parameters
-    ================
+    ----------------
     nwalkers: int, (200)
         The number of walkers
     nsteps: int, (100)
@@ -296,7 +296,7 @@ class Ptemcee(MCMCSampler):
         """Draw the initial positions from the prior
 
         Returns
-        =======
+        -------
         pos0: list
             The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
 
@@ -315,7 +315,7 @@ class Ptemcee(MCMCSampler):
         See pos0 in the class initialization for details.
 
         Returns
-        =======
+        -------
         pos0: list
             The initial positions of the walkers, with shape (ntemps, nwalkers, ndim)
 
@@ -504,6 +504,7 @@ class Ptemcee(MCMCSampler):
 
         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
+
         while True:
             for (pos0, log_posterior, log_likelihood) in sampler.sample(
                 self.pos0,
@@ -531,6 +532,7 @@ class Ptemcee(MCMCSampler):
                 )
 
             self.pos0 = pos0
+
             self.chain_array[:, self.iteration, :] = pos0[0, :, :]
             self.log_likelihood_array[:, :, self.iteration] = log_likelihood
             self.log_posterior_array[:, :, self.iteration] = log_posterior
@@ -538,6 +540,9 @@ class Ptemcee(MCMCSampler):
                 self.log_posterior_array[:, :, : self.iteration], axis=1
             )
 
+            # log_posterior_array is shaped (ntemps, nwalkers, iterations),
+            # so mean_log_posterior is shaped (ntemps, iterations)
+
             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
             t0 = datetime.datetime.now()
@@ -725,25 +730,87 @@ def check_iteration(
     beta_list,
     tau_list,
     tau_list_n,
-    Q_list,
+    gelman_rubin_list,
     mean_log_posterior,
     verbose=True,
 ):
-    """Per-iteration logic to calculate the convergence check
+    """Per-iteration logic to calculate the convergence check.
+
+    To check convergence, this function does the following:
+    1. Calculate the autocorrelation time (tau) for each dimension for each walker,
+       corresponding to those dimensions in search_parameter_keys that aren't
+       specifically excluded in ci.ignore_keys_for_tau.
+        a. Store the average tau for each dimension, averaged over walkers.
+    2. Calculate the Gelman-Rubin statistic (see `get_Q_convergence`), measuring
+       the convergence of the ensemble of walkers.
+    3. Calculate the number of effective samples: the total number of
+       post-burn-in samples (across all walkers), divided by the thinning
+       factor, a multiple of the current maximum average autocorrelation
+       time. Tuned by `ci.burn_in_nact` and `ci.thin_by_nact`.
+    4. If the Gelman-Rubin statistic < `ci.Q_tol` and `ci.nsamples` < the
+       number of effective samples, we say that our ensemble is converged,
+       setting `converged = True`.
+    5. For some number of the latest steps (set by the autocorrelation time
+       and the GRAD_WINDOW_LENGTH parameter), we find the maximum gradient
+       of the autocorrelation time over all of our dimensions (the
+       autocorrelation time has already been averaged over walkers), and the
+       maximum value of the gradient of the mean log posterior over
+       iterations, over all temperatures.
+    6. If the maximum gradient in tau is less than `ci.gradient_tau` and the
+       maximum gradient in the mean log posterior is less than
+       `ci.gradient_mean_log_posterior`, we set `tau_usable = True`.
+    7. If both `converged` and `tau_usable` are true, we return `stop = True`,
+       indicating that our ensemble is converged and burnt in on this
+       iteration.
+    8. Also prints progress! (see `print_progress`)
+
+    Notes
+    -----
+    The gradient of tau is computed with a Savitzky-Golay filter, over windows
+    in sample number of length `GRAD_WINDOW_LENGTH`. This value must be an odd
+    integer. For `ndim > 2`, we use the smallest odd integer greater than ndim.
+    For `ndim <= 2`, we instead use the smallest odd integer >= nwalkers, as
+    typically a window length much larger than the polynomial order (default 2)
+    leads to more stable smoothing.
 
     Parameters
-    ==========
+    ----------
+    iteration: int
+        Number indexing the current iteration, at which we are checking
+        convergence.
+    samples: np.ndarray
+        Array of ensemble MCMC samples, shaped like (number of walkers, number
+        of MCMC steps, number of dimensions).
+    sampler: bilby.core.sampler.Ptemcee
+        Bilby Ptemcee sampler object; in particular, this function uses the
+        inverse temperatures stored in `sampler.betas`.
     convergence_inputs: bilby.core.sampler.ptemcee.ConvergenceInputs
         A named tuple of the convergence checking inputs
     search_parameter_keys: list
         A list of the search parameter keys
     time_per_check, tau_list, tau_list_n: list
         Lists used for tracking the run
+    beta_list: list
+        List storing the inverse temperatures of the tempered chains.
+    tau_list: list
+        List of average autocorrelation times for each dimension, averaged
+        over walkers, at each checked iteration. So, an effective shape
+        of (number of iterations so far, number of dimensions).
+    tau_list_n: list
+        List of iteration numbers, enumerating the first "axis" of tau_list.
+        E.g. if tau_list_n[1] = 5, this means that the list found at
+        tau_list[1] was calculated on iteration number 5.
+    gelman_rubin_list: list of floats
+        List of values of the Gelman-Rubin statistic; the value calculated
+        in this call of check_iteration is appended to this list.
+    mean_log_posterior: np.ndarray
+        Float array shaped like (number of temperatures, number of MCMC
+        steps), with the log of the posterior averaged over walkers.
     verbose: bool
         Whether to print the output
 
     Returns
-    =======
+    -------
     stop: bool
         A boolean flag, True if the stopping criteria has been met
     burn: int
@@ -757,14 +824,9 @@ def check_iteration(
     """
 
     ci = convergence_inputs
-    # Note: nsteps is the number of steps in the samples while iterations is
-    # the current iteration number. So iteration > nsteps by the number of
-    # of discards
-    nwalkers, nsteps, ndim = samples.shape
 
+    nwalkers, nsteps, ndim = samples.shape
     tau_array = calculate_tau_array(samples, search_parameter_keys, ci)
-
-    # Maximum over parameters, mean over walkers
     tau = np.max(np.mean(tau_array, axis=0))
 
     # Apply multiplicitive safety factor
@@ -775,8 +837,8 @@ def check_iteration(
     tau_list.append(list(np.mean(tau_array, axis=0)))
     tau_list_n.append(iteration)
 
-    Q = get_Q_convergence(samples)
-    Q_list.append(Q)
+    gelman_rubin_statistic = get_Q_convergence(samples)
+    gelman_rubin_list.append(gelman_rubin_statistic)
 
     if np.isnan(tau) or np.isinf(tau):
         if verbose:
@@ -791,7 +853,7 @@ def check_iteration(
                 np.nan,
                 False,
                 convergence_inputs,
-                Q,
+                gelman_rubin_statistic,
             )
         return False, np.nan, np.nan, np.nan, np.nan
 
@@ -805,18 +867,24 @@ def check_iteration(
     nsamples_effective = int(nwalkers * (nsteps - nburn) / thin)
 
     # Calculate convergence boolean
-    converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
+    converged = gelman_rubin_statistic < ci.Q_tol and ci.nsamples < nsamples_effective
     logger.debug(
-        f"Convergence: Q<Q_tol={Q < ci.Q_tol}, "
-        f"nsamples<nsamples_effective={ci.nsamples < nsamples_effective}"
+        "Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}".format(
+            gelman_rubin_statistic < ci.Q_tol, ci.nsamples < nsamples_effective
+        )
     )
 
-    GRAD_WINDOW_LENGTH = nwalkers + 1
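+    # Savitzky-Golay windows must be odd and longer than the polynomial order
+    # (2): use the smallest odd integer greater than ndim, falling back to an
+    # odd window based on nwalkers when ndim is too small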
+    GRAD_WINDOW_LENGTH = 2 * ((ndim + 1) // 2) + 1
+    if GRAD_WINDOW_LENGTH <= 3:
+        GRAD_WINDOW_LENGTH = 2 * (nwalkers // 2) + 1
+
     nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int])
     lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check])
     check_taus = np.array(tau_list[lower_tau_index:])
     if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
-        gradient_tau = get_max_gradient(check_taus, axis=0, window_length=11)
+        gradient_tau = get_max_gradient(
+            check_taus, axis=0, window_length=GRAD_WINDOW_LENGTH
+        )
 
         if gradient_tau < ci.gradient_tau:
             logger.debug(
@@ -876,13 +944,51 @@ def check_iteration(
             gradient_mean_log_posterior,
             tau_usable,
             convergence_inputs,
-            Q,
+            gelman_rubin_statistic,
         )
+
     stop = converged and tau_usable
     return stop, nburn, thin, tau_int, nsamples_effective
 
 
 def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False):
+    """Calculate the maximum value of the gradient in the input data.
+
+    Applies a Savitzky-Golay filter (`scipy.signal.savgol_filter`) to the input
+    data x, along a particular axis. This filter smooths the data and, as configured
+    in this function, simultaneously calculates the derivative of the smoothed data.
+    If smooth=True is provided, it will apply a Savitzky-Golay filter with a
+    polynomial order of 3 to the input data before applying this filter a second
+    time and calculating the derivative. This function will return the maximum value
+    of the derivative returned by the filter.
+
+    See https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.savgol_filter.html
+    for more information on the Savitzky-Golay filter that we use. Some parameter
+    documentation has been borrowed from this source.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Array of input data (can be int or float, as `savgol_filter` casts to float).
+    axis : int, default = 0
+        The axis of the input array x over which to calculate the gradient.
+    window_length : int, default = 11
+        The length of the filter window (i.e., the number of coefficients to use
+        in approximating the data).
+    polyorder : int, default = 2
+        The order of the polynomial used to fit the samples. polyorder must be less
+        than window_length.
+    smooth : bool, default = False
+        If true, this will smooth the data with a Savitzky-Golay filter before
+        providing it to the Savitzky-Golay filter for calculating the derivative.
+        Probably useful if you think your input data is especially noisy.
+
+    Returns
+    -------
+    max_gradient : float
+        Maximum value of the gradient.
+    """
+
     from scipy.signal import savgol_filter
 
     if smooth:
@@ -895,12 +1001,54 @@ def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False):
 
 
 def get_Q_convergence(samples):
+    """Calculate the Gelman-Rubin statistic as an estimate of convergence for
+    an ensemble of MCMC walkers.
+
+    Calculates the Gelman-Rubin statistic, from Gelman and Rubin (1992).
+    See section 2.2 of Gelman and Rubin (1992), at
+    https://doi.org/10.1214/ss/1177011136.
+
+    There is also a good description of this statistic in section 7.4.2
+    of "Advanced Statistical Computing" (Peng 2021), in-progress course notes,
+    currently found at
+    https://bookdown.org/rdpeng/advstatcomp/monitoring-convergence.html.
+    As of this writing, L in this resource refers to the number of sampling
+    steps remaining after some have been discarded to achieve burn-in,
+    equivalent to nsteps here. Paraphrasing, we compare the variance between
+    our walkers (chains) to the variance within each walker (compare
+    inter-walker vs. intra-walker variance). We do this because our walkers
+    should be indistinguishable from one another when they reach a steady-state,
+    i.e. convergence. Looking at V-hat in the definition of this function, we
+    can see that as nsteps -> infinity, the contribution of B (the inter-chain
+    variance) to Vhat vanishes and R -> 1; so, R close to 1 is a good indicator
+    of the convergence of our ensemble.
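+
+    Concretely, with W the mean within-walker variance and B the
+    between-walker variance as computed below,
+    Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B,
+    and the per-dimension statistic is R = sqrt(Vhat / W).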
+
+    In practice, this function calculates the Gelman-Rubin statistic for
+    each dimension, and then returns the largest of these values. This
+    means that we can be sure that, once the dimension with the largest such
+    value achieves a Gelman-Rubin statistic close to 1, the others have as well.
+
+    Parameters
+    ----------
+    samples: np.ndarray
+        Array of ensemble MCMC samples, shaped like (number of walkers, number
+        of MCMC steps, number of dimensions).
+
+    Returns
+    -------
+    Q: float
+        The largest value of the Gelman-Rubin statistic, from those values
+        calculated for each dimension. If only one step is represented in
+        samples, this returns np.inf.
+    """
+
     nwalkers, nsteps, ndim = samples.shape
     if nsteps > 1:
         W = np.mean(np.var(samples, axis=1), axis=0)
+
         per_walker_mean = np.mean(samples, axis=1)
         mean = np.mean(per_walker_mean, axis=0)
         B = nsteps / (nwalkers - 1.0) * np.sum((per_walker_mean - mean) ** 2, axis=0)
+
         Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
         Q_per_dim = np.sqrt(Vhat / W)
         return np.max(Q_per_dim)
@@ -977,7 +1125,31 @@ def print_progress(
 
 
 def calculate_tau_array(samples, search_parameter_keys, ci):
-    """Compute ACT tau for 0-temperature chains"""
+    """Calculate the autocorrelation time for zero-temperature chains.
+
+    Calculates the autocorrelation time for each chain, for those parameters/
+    dimensions that are not explicitly excluded in ci.ignore_keys_for_tau.
+
+    Parameters
+    ----------
+    samples: np.ndarray
+        Array of ensemble MCMC samples, shaped like (number of walkers, number
+        of MCMC steps, number of dimensions).
+    search_parameter_keys: list
+        A list of the search parameter keys
+    ci : collections.namedtuple
+        Collection of settings for convergence tests, including autocorrelation
+        calculation. If a value in search_parameter_keys is included in
+        ci.ignore_keys_for_tau, this function will not calculate an
+        autocorrelation time for any walker along that particular dimension.
+
+    Returns
+    -------
+    tau_array: np.ndarray
+        Float array shaped like (nwalkers, ndim) (with all np.inf for any
+        dimension that is excluded by ci.ignore_keys_for_tau).
+    """
+
     import emcee
 
     nwalkers, nsteps, ndim = samples.shape
@@ -1141,9 +1313,11 @@ def plot_tau(
 def plot_mean_log_posterior(mean_log_posterior, outdir, label):
     import matplotlib.pyplot as plt
 
+    mean_log_posterior[mean_log_posterior < -1e100] = np.nan
+
     ntemps, nsteps = mean_log_posterior.shape
-    ymax = np.max(mean_log_posterior)
-    ymin = np.min(mean_log_posterior[:, -100:])
+    ymax = np.nanmax(mean_log_posterior)
+    ymin = np.nanmin(mean_log_posterior[:, -100:])
     ymax += 0.1 * (ymax - ymin)
     ymin -= 0.1 * (ymax - ymin)
 
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index a78c17629939af8d85870694a8f568423e855364..3d207e4833ee7facb3df9ee25791bbc0ebb5e1de 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -830,6 +830,9 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion,
     output_sample = fill_from_fixed_priors(output_sample, priors)
     output_sample, _ = base_conversion(output_sample)
     if likelihood is not None:
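+        # Add the per-detector log likelihoods to the output samples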
+        compute_per_detector_log_likelihoods(
+            samples=output_sample, likelihood=likelihood, npool=npool)
+
         marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
         if len(marginalized_parameters) > 0:
             try:
@@ -1466,6 +1469,117 @@ def _compute_snrs(args):
     return snrs
 
 
+def compute_per_detector_log_likelihoods(samples, likelihood, npool=1, block=10):
+    """
+    Calculate the log likelihoods in each detector.
+
+    Parameters
+    ==========
+    samples: DataFrame, dict
+        Posterior samples from the run; a dict is treated as a single sample.
+    likelihood: bilby.gw.likelihood.GravitationalWaveTransient
+        Likelihood used during sampling.
+    npool: int, (default=1)
+        If given, perform generation (where possible) using a multiprocessing pool
+    block: int, (default=10)
+        Size of the blocks to use in multiprocessing
+
+    Returns
+    =======
+    sample: DataFrame
+        Returns the posterior with new samples.
+    """
+    if likelihood is not None:
+        if not callable(getattr(likelihood, 'compute_per_detector_log_likelihood', None)):
+            logger.debug('Not computing per-detector log likelihoods.')
+            return samples
+
+        if isinstance(samples, dict):
+            likelihood.parameters.update(samples)
+            samples = likelihood.compute_per_detector_log_likelihood()
+            return samples
+
+        elif not isinstance(samples, DataFrame):
+            raise ValueError("Unable to handle input samples of type {}".format(type(samples)))
+        from tqdm.auto import tqdm
+
+        logger.info('Computing per-detector log likelihoods.')
+
+        # Initialize cache dict
+        cached_samples_dict = dict()
+
+        # Store samples to convert for checking
+        cached_samples_dict["_samples"] = samples
+
+        # Set up the multiprocessing
+        if npool > 1:
+            from ..core.sampler.base_sampler import _initialize_global_variables
+            pool = multiprocessing.Pool(
+                processes=npool,
+                initializer=_initialize_global_variables,
+                initargs=(likelihood, None, None, False),
+            )
+            logger.info(
+                "Using a pool with size {} for nsamples={}"
+                .format(npool, len(samples))
+            )
+        else:
+            from ..core.sampler.base_sampler import _sampling_convenience_dump
+            _sampling_convenience_dump.likelihood = likelihood
+            pool = None
+
+        fill_args = [(ii, row) for ii, row in samples.iterrows()]
+        ii = 0
+        pbar = tqdm(total=len(samples), file=sys.stdout)
+        while ii < len(samples):
+            if ii in cached_samples_dict:
+                ii += block
+                pbar.update(block)
+                continue
+
+            if pool is not None:
+                subset_samples = pool.map(_compute_per_detector_log_likelihoods,
+                                          fill_args[ii: ii + block])
+            else:
+                subset_samples = [list(_compute_per_detector_log_likelihoods(xx))
+                                  for xx in fill_args[ii: ii + block]]
+
+            cached_samples_dict[ii] = subset_samples
+
+            ii += block
+            pbar.update(len(subset_samples))
+        pbar.close()
+
+        if pool is not None:
+            pool.close()
+            pool.join()
+
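+        # Block results were cached in ascending order of their start index,
+        # so concatenating the cached values preserves the sample order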
+        new_samples = np.concatenate(
+            [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]
+        )
+
+        for ii, key in \
+                enumerate([f'{ifo.name}_log_likelihood' for ifo in likelihood.interferometers]):
+            samples[key] = new_samples[:, ii]
+
+        return samples
+
+    else:
+        logger.debug('Not computing per-detector log likelihoods.')
+        return samples
+
+
+def _compute_per_detector_log_likelihoods(args):
+    """A wrapper of computing the per-detector log likelihoods to enable multiprocessing"""
+    from ..core.sampler.base_sampler import _sampling_convenience_dump
+    likelihood = _sampling_convenience_dump.likelihood
+    ii, sample = args
+    sample = dict(sample)
+    likelihood.parameters.update(sample)
+    new_sample = likelihood.compute_per_detector_log_likelihood()
+    return tuple((new_sample[key] for key in
+                  [f'{ifo.name}_log_likelihood' for ifo in likelihood.interferometers]))
+
+
 def generate_posterior_samples_from_marginalized_likelihood(
         samples, likelihood, npool=1, block=10, use_cache=True):
     """
diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 6f4456ad3d799711212bafe4607e6af455a01cfb..9f8c29147aec9fab09f7a7dc63882588c44fd2e0 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -627,7 +627,12 @@ class Interferometer(object):
         return self.strain_data.frequency_domain_strain / self.amplitude_spectral_density_array
 
     def save_data(self, outdir, label=None):
-        """ Creates a save file for the data in plain text format
+        """ Creates save files for interferometer data in plain text format.
+
+        Saves two files: the frequency domain strain data with three columns [f, real part of h(f),
+        imaginary part of h(f)], and the amplitude spectral density with two columns [f, ASD(f)].
+
+        Note that in v1.3.0 and below, the ASD was saved in a file called *_psd.dat.
 
         Parameters
         ==========
@@ -638,10 +643,10 @@ class Interferometer(object):
         """
 
         if label is None:
-            filename_psd = '{}/{}_psd.dat'.format(outdir, self.name)
+            filename_asd = '{}/{}_asd.dat'.format(outdir, self.name)
             filename_data = '{}/{}_frequency_domain_data.dat'.format(outdir, self.name)
         else:
-            filename_psd = '{}/{}_{}_psd.dat'.format(outdir, self.name, label)
+            filename_asd = '{}/{}_{}_asd.dat'.format(outdir, self.name, label)
             filename_data = '{}/{}_{}_frequency_domain_data.dat'.format(outdir, self.name, label)
         np.savetxt(filename_data,
                    np.array(
@@ -649,7 +654,7 @@ class Interferometer(object):
                         self.strain_data.frequency_domain_strain.real,
                         self.strain_data.frequency_domain_strain.imag]).T,
                    header='f real_h(f) imag_h(f)')
-        np.savetxt(filename_psd,
+        np.savetxt(filename_asd,
                    np.array(
                        [self.strain_data.frequency_array,
                         self.amplitude_spectral_density_array]).T,
diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index 26561fca0d5f5ac895ad2d2470e98a4217e779b0..34c71390f25032441deb146aa2ae3947e2f4db18 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -115,6 +115,36 @@ class GravitationalWaveTransient(Likelihood):
         optimal_snr_squared_array = attr.ib()
         d_inner_h_squared_tc_array = attr.ib()
 
+        def __add__(self, other_snr):
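+            """Combine the SNRs from two detectors; array attributes are summed only if set in both."""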
+
+            total_d_inner_h = self.d_inner_h + other_snr.d_inner_h
+            total_optimal_snr_squared = self.optimal_snr_squared + \
+                np.real(other_snr.optimal_snr_squared)
+            total_complex_matched_filter_snr = self.complex_matched_filter_snr + \
+                other_snr.complex_matched_filter_snr
+
+            total_d_inner_h_array = self.d_inner_h_array
+            if other_snr.d_inner_h_array is not None \
+                    and self.d_inner_h_array is not None:
+                total_d_inner_h_array += other_snr.d_inner_h_array
+
+            total_optimal_snr_squared_array = self.optimal_snr_squared_array
+            if other_snr.optimal_snr_squared_array is not None \
+                    and self.optimal_snr_squared_array is not None:
+                total_optimal_snr_squared_array += other_snr.optimal_snr_squared_array
+
+            total_d_inner_h_squared_tc_array = self.d_inner_h_squared_tc_array
+            if other_snr.d_inner_h_squared_tc_array is not None \
+                    and self.d_inner_h_squared_tc_array is not None:
+                total_d_inner_h_squared_tc_array += other_snr.d_inner_h_squared_tc_array
+
+            return self.__class__(d_inner_h=total_d_inner_h,
+                                  optimal_snr_squared=total_optimal_snr_squared,
+                                  complex_matched_filter_snr=total_complex_matched_filter_snr,
+                                  d_inner_h_array=total_d_inner_h_array,
+                                  optimal_snr_squared_array=total_optimal_snr_squared_array,
+                                  d_inner_h_squared_tc_array=total_d_inner_h_squared_tc_array)
+
     def __init__(
             self, interferometers, waveform_generator, time_marginalization=False,
             distance_marginalization=False, phase_marginalization=False, calibration_marginalization=False, priors=None,
@@ -366,79 +396,99 @@ class GravitationalWaveTransient(Likelihood):
         waveform_polarizations = \
             self.waveform_generator.frequency_domain_strain(self.parameters)
 
+        if self.time_marginalization and self.jitter_time:
+            self.parameters['geocent_time'] += self.parameters['time_jitter']
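+            # The jitter is subtracted again after the likelihood has been computed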
+
         self.parameters.update(self.get_sky_frame_parameters())
 
         if waveform_polarizations is None:
             return np.nan_to_num(-np.inf)
 
-        d_inner_h = 0.
-        optimal_snr_squared = 0.
-        complex_matched_filter_snr = 0.
+        total_snrs = self._CalculatedSNRs(
+            d_inner_h=0., optimal_snr_squared=0., complex_matched_filter_snr=0.,
+            d_inner_h_array=None, optimal_snr_squared_array=None, d_inner_h_squared_tc_array=None)
 
         if self.time_marginalization and self.calibration_marginalization:
-            if self.jitter_time:
-                self.parameters['geocent_time'] += self.parameters['time_jitter']
-
-            d_inner_h_array = np.zeros(
+            total_snrs.d_inner_h_array = np.zeros(
                 (self.number_of_response_curves, len(self.interferometers.frequency_array[0:-1])),
                 dtype=np.complex128)
-            optimal_snr_squared_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
+            total_snrs.optimal_snr_squared_array = \
+                np.zeros(self.number_of_response_curves, dtype=np.complex128)
 
         elif self.time_marginalization:
-            if self.jitter_time:
-                self.parameters['geocent_time'] += self.parameters['time_jitter']
-            d_inner_h_array = np.zeros(len(self._times), dtype=np.complex128)
+            total_snrs.d_inner_h_array = np.zeros(len(self._times), dtype=np.complex128)
 
         elif self.calibration_marginalization:
-            d_inner_h_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
-            optimal_snr_squared_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
+            total_snrs.d_inner_h_array = \
+                np.zeros(self.number_of_response_curves, dtype=np.complex128)
+            total_snrs.optimal_snr_squared_array = \
+                np.zeros(self.number_of_response_curves, dtype=np.complex128)
 
         for interferometer in self.interferometers:
             per_detector_snr = self.calculate_snrs(
                 waveform_polarizations=waveform_polarizations,
                 interferometer=interferometer)
 
-            d_inner_h += per_detector_snr.d_inner_h
-            optimal_snr_squared += np.real(per_detector_snr.optimal_snr_squared)
-            complex_matched_filter_snr += per_detector_snr.complex_matched_filter_snr
+            total_snrs += per_detector_snr
+
+        log_l = self.compute_log_likelihood_from_snrs(total_snrs)
 
-            if self.time_marginalization or self.calibration_marginalization:
-                d_inner_h_array += per_detector_snr.d_inner_h_array
+        if self.time_marginalization and self.jitter_time:
+            self.parameters['geocent_time'] -= self.parameters['time_jitter']
 
-            if self.calibration_marginalization:
-                optimal_snr_squared_array += per_detector_snr.optimal_snr_squared_array
+        return float(log_l.real)
+
+    def compute_log_likelihood_from_snrs(self, total_snrs):
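+        """Compute the log likelihood from the total SNRs, applying the appropriate marginalisations."""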
 
         if self.calibration_marginalization and self.time_marginalization:
             log_l = self.time_and_calibration_marginalized_likelihood(
-                d_inner_h_array=d_inner_h_array,
-                h_inner_h=optimal_snr_squared_array)
-            if self.jitter_time:
-                self.parameters['geocent_time'] -= self.parameters['time_jitter']
+                d_inner_h_array=total_snrs.d_inner_h_array,
+                h_inner_h=total_snrs.optimal_snr_squared_array)
 
         elif self.calibration_marginalization:
             log_l = self.calibration_marginalized_likelihood(
-                d_inner_h_calibration_array=d_inner_h_array,
-                h_inner_h=optimal_snr_squared_array)
+                d_inner_h_calibration_array=total_snrs.d_inner_h_array,
+                h_inner_h=total_snrs.optimal_snr_squared_array)
 
         elif self.time_marginalization:
             log_l = self.time_marginalized_likelihood(
-                d_inner_h_tc_array=d_inner_h_array,
-                h_inner_h=optimal_snr_squared)
-            if self.jitter_time:
-                self.parameters['geocent_time'] -= self.parameters['time_jitter']
+                d_inner_h_tc_array=total_snrs.d_inner_h_array,
+                h_inner_h=total_snrs.optimal_snr_squared)
 
         elif self.distance_marginalization:
             log_l = self.distance_marginalized_likelihood(
-                d_inner_h=d_inner_h, h_inner_h=optimal_snr_squared)
+                d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared)
 
         elif self.phase_marginalization:
             log_l = self.phase_marginalized_likelihood(
-                d_inner_h=d_inner_h, h_inner_h=optimal_snr_squared)
+                d_inner_h=total_snrs.d_inner_h, h_inner_h=total_snrs.optimal_snr_squared)
 
         else:
-            log_l = np.real(d_inner_h) - optimal_snr_squared / 2
+            log_l = np.real(total_snrs.d_inner_h) - total_snrs.optimal_snr_squared / 2
 
-        return float(log_l.real)
+        return log_l
+
+    def compute_per_detector_log_likelihood(self):
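+        """Calculate the log likelihood in each detector and store the values in self.parameters."""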
+        waveform_polarizations = \
+            self.waveform_generator.frequency_domain_strain(self.parameters)
+
+        if self.time_marginalization and self.jitter_time:
+            self.parameters['geocent_time'] += self.parameters['time_jitter']
+
+        self.parameters.update(self.get_sky_frame_parameters())
+
+        for interferometer in self.interferometers:
+            per_detector_snr = self.calculate_snrs(
+                waveform_polarizations=waveform_polarizations,
+                interferometer=interferometer)
+
+            self.parameters['{}_log_likelihood'.format(interferometer.name)] = \
+                self.compute_log_likelihood_from_snrs(per_detector_snr)
+
+        if self.time_marginalization and self.jitter_time:
+            self.parameters['geocent_time'] -= self.parameters['time_jitter']
+
+        return self.parameters.copy()
 
     def generate_posterior_sample_from_marginalized_likelihood(self):
         """
@@ -777,32 +827,29 @@ class GravitationalWaveTransient(Likelihood):
             signal_polarizations = \
                 self.waveform_generator.frequency_domain_strain(self.parameters)
 
-        d_inner_h = 0.
-        optimal_snr_squared = 0.
-        complex_matched_filter_snr = 0.
-        d_inner_h_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
-        optimal_snr_squared_array = np.zeros(self.number_of_response_curves, dtype=np.complex128)
+        total_snrs = self._CalculatedSNRs(
+            d_inner_h=0., optimal_snr_squared=0., complex_matched_filter_snr=0.,
+            d_inner_h_array=np.zeros(self.number_of_response_curves, dtype=np.complex128),
+            optimal_snr_squared_array=np.zeros(self.number_of_response_curves, dtype=np.complex128),
+            d_inner_h_squared_tc_array=None)
 
         for interferometer in self.interferometers:
             per_detector_snr = self.calculate_snrs(
                 waveform_polarizations=signal_polarizations,
                 interferometer=interferometer)
 
-            d_inner_h += per_detector_snr.d_inner_h
-            optimal_snr_squared += np.real(per_detector_snr.optimal_snr_squared)
-            complex_matched_filter_snr += per_detector_snr.complex_matched_filter_snr
-            d_inner_h_array += per_detector_snr.d_inner_h_array
-            optimal_snr_squared_array += per_detector_snr.optimal_snr_squared_array
+            total_snrs += per_detector_snr
 
         if self.distance_marginalization:
             log_l_cal_array = self.distance_marginalized_likelihood(
-                d_inner_h=d_inner_h_array, h_inner_h=optimal_snr_squared_array)
+                d_inner_h=total_snrs.d_inner_h_array,
+                h_inner_h=total_snrs.optimal_snr_squared_array)
         elif self.phase_marginalization:
             log_l_cal_array = self.phase_marginalized_likelihood(
-                d_inner_h=d_inner_h_array,
-                h_inner_h=optimal_snr_squared_array)
+                d_inner_h=total_snrs.d_inner_h_array,
+                h_inner_h=total_snrs.optimal_snr_squared_array)
         else:
-            log_l_cal_array = np.real(d_inner_h_array - optimal_snr_squared_array / 2)
+            log_l_cal_array = \
+                np.real(total_snrs.d_inner_h_array - total_snrs.optimal_snr_squared_array / 2)
 
         return log_l_cal_array
 
diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py
index 9d0b7e1d154774a95469154918ab55601bf626a3..88e7234a20820d1eee151118b717a5d0c6f91cb8 100644
--- a/bilby/gw/likelihood/multiband.py
+++ b/bilby/gw/likelihood/multiband.py
@@ -8,6 +8,7 @@ from ...core.utils import (
     logger, speed_of_light, solar_mass, radius_of_earth,
     gravitational_constant, round_up_to_power_of_two
 )
+from ..prior import CBCPriorDict
 
 
 class MBGravitationalWaveTransient(GravitationalWaveTransient):
@@ -21,8 +22,9 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         A list of `bilby.detector.Interferometer` instances - contains the detector data and power spectral densities
     waveform_generator: `bilby.waveform_generator.WaveformGenerator`
         An object which computes the frequency-domain strain of the signal, given some set of parameters
-    reference_chirp_mass: float
-        A reference chirp mass for determining the frequency banding
+    reference_chirp_mass: float, optional
+        A reference chirp mass for determining the frequency banding. This is set to the prior minimum
+        of the chirp mass if not specified. Hence, a CBCPriorDict object must be passed to priors when
+        this parameter is not specified.
     highest_mode: int, optional
         The maximum magnetic number of gravitational-wave moments. Default is 2
     linear_interpolation: bool, optional
@@ -72,10 +74,11 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
 
     """
     def __init__(
-            self, interferometers, waveform_generator, reference_chirp_mass, highest_mode=2, linear_interpolation=True,
-            accuracy_factor=5, time_offset=None, delta_f_end=None, maximum_banding_frequency=None,
-            minimum_banding_duration=0., distance_marginalization=False, phase_marginalization=False, priors=None,
-            distance_marginalization_lookup_table=None, reference_frame="sky", time_reference="geocenter"
+            self, interferometers, waveform_generator, reference_chirp_mass=None, highest_mode=2,
+            linear_interpolation=True, accuracy_factor=5, time_offset=None, delta_f_end=None,
+            maximum_banding_frequency=None, minimum_banding_duration=0., distance_marginalization=False,
+            phase_marginalization=False, priors=None, distance_marginalization_lookup_table=None,
+            reference_frame="sky", time_reference="geocenter"
     ):
         super(MBGravitationalWaveTransient, self).__init__(
             interferometers=interferometers, waveform_generator=waveform_generator, priors=priors,
@@ -108,7 +111,24 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
         if isinstance(reference_chirp_mass, int) or isinstance(reference_chirp_mass, float):
             self._reference_chirp_mass = reference_chirp_mass
         else:
-            raise TypeError("reference_chirp_mass must be a number")
+            logger.info(
+                "No int or float number has been passed to reference_chirp_mass. "
+                "Checking prior minimum of chirp mass ..."
+            )
+            if not isinstance(self.priors, CBCPriorDict):
+                raise TypeError(
+                    f"priors: {self.priors} is not CBCPriorDict. Prior minimum of chirp mass can not be obtained."
+                )
+            self._reference_chirp_mass = self.priors.minimum_chirp_mass
+            if self._reference_chirp_mass is None:
+                raise Exception(
+                    "Prior minimum of chirp mass can not be determined as priors does not contain necessary mass "
+                    "parameters."
+                )
+            logger.info(
+                "reference_chirp_mass is automatically set to prior minimum of chirp mass: "
+                f"{self._reference_chirp_mass}."
+            )
 
     @property
     def highest_mode(self):
diff --git a/sampler_requirements.txt b/sampler_requirements.txt
index d6ed8e98ebce0797864ea463672eaf04ac600fd0..29f38dc1973f870ce1556b08889fc424bcb753b9 100644
--- a/sampler_requirements.txt
+++ b/sampler_requirements.txt
@@ -9,6 +9,6 @@ pymultinest
 kombine
 ultranest>=3.0.0
 dnest4
-nessai>=0.2.3
+nessai>=0.7.0
 schwimmbad
 zeus-mcmc>=2.3.0
diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py
index cbb084735ec50274b45d2dd629772c92d0d3daed..0cac7a45b24e9174336ed454e11908fd0e0e6555 100644
--- a/test/core/sampler/nessai_test.py
+++ b/test/core/sampler/nessai_test.py
@@ -21,9 +21,9 @@ class TestNessai(unittest.TestCase):
             plot=False,
             skip_import_verification=True,
             sampling_seed=150914,
-            npool=None,  # TODO: remove when support for nessai<0.7.0 is dropped
         )
         self.expected = self.sampler.default_kwargs
+        self.expected["n_pool"] = 1  # Because npool=1 by default
         self.expected['output'] = 'outdir/label_nessai/'
         self.expected['seed'] = 150914
 
@@ -48,28 +48,31 @@ class TestNessai(unittest.TestCase):
 
     def test_translate_kwargs_npool(self):
         expected = self.expected.copy()
-        expected["n_pool"] = None
+        expected["n_pool"] = 2
         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
             del new_kwargs["n_pool"]
-            new_kwargs[equiv] = None
+            new_kwargs[equiv] = 2
             self.sampler.kwargs = new_kwargs
             self.assertDictEqual(expected, self.sampler.kwargs)
 
-    def test_translate_kwargs_seed(self):
-        assert self.expected["seed"] == 150914
+    def test_split_kwargs(self):
+        kwargs, run_kwargs = self.sampler.split_kwargs()
+        assert "save" not in run_kwargs
+        assert "plot" in run_kwargs
 
-    def test_npool_max_threads(self):
-        # TODO: remove when support for nessai<0.7.0 is dropped
+    def test_translate_kwargs_no_npool(self):
         expected = self.expected.copy()
-        expected["n_pool"] = None
-        expected["max_threads"] = 1
+        expected["n_pool"] = 3
         new_kwargs = self.sampler.kwargs.copy()
-        new_kwargs["n_pool"] = 1
-        new_kwargs["max_threads"] = 1
+        del new_kwargs["n_pool"]
+        self.sampler._npool = 3
         self.sampler.kwargs = new_kwargs
         self.assertDictEqual(expected, self.sampler.kwargs)
 
+    def test_translate_kwargs_seed(self):
+        assert self.expected["seed"] == 150914
+
     @patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
     def test_update_from_config_file(self):
         expected = self.expected.copy()
diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py
index cadeae1bd55e98cc578e4d6e08bfd6cd56c70bda..6cf05de66275debe90bb6069c04947507f7e1f3a 100644
--- a/test/gw/conversion_test.py
+++ b/test/gw/conversion_test.py
@@ -473,7 +473,7 @@ class TestGenerateAllParameters(unittest.TestCase):
             for key in expected:
                 self.assertIn(key, new_parameters)
 
-    def test_generate_bbh_paramters_with_likelihood(self):
+    def test_generate_bbh_parameters_with_likelihood(self):
         priors = bilby.gw.prior.BBHPriorDict()
         priors["geocent_time"] = bilby.core.prior.Uniform(0.4, 0.6)
         ifos = bilby.gw.detector.InterferometerList(["H1"])
@@ -489,9 +489,9 @@ class TestGenerateAllParameters(unittest.TestCase):
             time_marginalization=True,
             reference_frame="H1L1",
         )
-        self.parameters["zenith"] = 0
-        self.parameters["azimuth"] = 0
-        self.parameters["time_jitter"] = 0
+        self.parameters["zenith"] = 0.0
+        self.parameters["azimuth"] = 0.0
+        self.parameters["time_jitter"] = 0.0
         del self.parameters["ra"], self.parameters["dec"]
         self.parameters = pd.DataFrame(self.parameters, index=range(1))
         converted = bilby.gw.conversion.generate_all_bbh_parameters(
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index be103d1b0053e3028ce862fe074b74478d679056..8428ee0eab3a178f968584408ef522f7ff67a52c 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -2,7 +2,6 @@ import itertools
 import os
 import pytest
 import unittest
-from copy import deepcopy
 from itertools import product
 from parameterized import parameterized
 
@@ -1592,9 +1591,9 @@ class TestBBHLikelihoodSetUp(unittest.TestCase):
 
 class TestMBLikelihood(unittest.TestCase):
     def setUp(self):
-        duration = 16
-        fmin = 20.
-        sampling_frequency = 2048.
+        self.duration = 16
+        self.fmin = 20.
+        self.sampling_frequency = 2048.
         self.test_parameters = dict(
             chirp_mass=6.0,
             mass_ratio=0.5,
@@ -1613,18 +1612,18 @@ class TestMBLikelihood(unittest.TestCase):
             dec=-1.2
         )  # Network SNR is ~50
 
-        ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
+        self.ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
         np.random.seed(170817)
-        ifos.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=sampling_frequency, duration=duration,
-            start_time=self.test_parameters['geocent_time'] - duration + 2.
+        self.ifos.set_strain_data_from_power_spectral_densities(
+            sampling_frequency=self.sampling_frequency, duration=self.duration,
+            start_time=self.test_parameters['geocent_time'] - self.duration + 2.
         )
-        for ifo in ifos:
-            ifo.minimum_frequency = fmin
+        for ifo in self.ifos:
+            ifo.minimum_frequency = self.fmin
 
         spline_calibration_nodes = 10
         self.calibration_parameters = {}
-        for ifo in ifos:
+        for ifo in self.ifos:
             ifo.calibration_model = bilby.gw.calibration.CubicSpline(
                 prefix=f"recalib_{ifo.name}_",
                 minimum_frequency=ifo.minimum_frequency,
@@ -1640,143 +1639,168 @@ class TestMBLikelihood(unittest.TestCase):
                 self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \
                     np.random.normal(loc=0, scale=5 * np.pi / 180)
 
-        priors = bilby.gw.prior.BBHPriorDict()
-        priors.pop("mass_1")
-        priors.pop("mass_2")
-        priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5)
-        priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
-        priors["geocent_time"] = bilby.core.prior.Uniform(
+        self.priors = bilby.gw.prior.BBHPriorDict()
+        self.priors.pop("mass_1")
+        self.priors.pop("mass_2")
+        self.priors["chirp_mass"] = bilby.core.prior.Uniform(5.5, 6.5)
+        self.priors["mass_ratio"] = bilby.core.prior.Uniform(0.125, 1)
+        self.priors["geocent_time"] = bilby.core.prior.Uniform(
             self.test_parameters['geocent_time'] - 0.1,
             self.test_parameters['geocent_time'] + 0.1)
 
-        approximant_22 = "IMRPhenomD"
-        approximant_homs = "IMRPhenomHM"
-        non_mb_wfg_22 = bilby.gw.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
+    def tearDown(self):
+        del (
+            self.ifos,
+            self.priors
+        )
+
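+    # cases: (approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance)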
+    @parameterized.expand([
+        ("IMRPhenomD", True, 2, False, 1.5e-2),
+        ("IMRPhenomD", True, 2, True, 1.5e-2),
+        ("IMRPhenomD", False, 2, False, 5e-3),
+        ("IMRPhenomD", False, 2, True, 6e-3),
+        ("IMRPhenomHM", False, 4, False, 8e-4),
+        ("IMRPhenomHM", False, 4, True, 1e-3)
+    ])
+    def test_matches_original_likelihood(
+        self, approximant, linear_interpolation, highest_mode, add_cal_errors, tolerance
+    ):
+        """
+        Check that the multi-band likelihood values match the original likelihood values.
+        """
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
             waveform_arguments=dict(
-                reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_22)
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        mb_wfg_22 = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
+        self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
+
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
             frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
             waveform_arguments=dict(
-                reference_frequency=fmin, approximant=approximant_22)
-        )
-        non_mb_wfg_homs = bilby.gw.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-            waveform_arguments=dict(
-                reference_frequency=fmin, minimum_frequency=fmin, approximant=approximant_homs)
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        mb_wfg_homs = bilby.gw.waveform_generator.WaveformGenerator(
-            duration=duration, sampling_frequency=sampling_frequency,
-            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
-            waveform_arguments=dict(
-                reference_frequency=fmin, approximant=approximant_homs)
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg
         )
-
-        ifos_22 = deepcopy(ifos)
-        ifos_22.inject_signal(
-            parameters=self.test_parameters, waveform_generator=non_mb_wfg_22
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.test_parameters['chirp_mass'],
+            priors=self.priors.copy(), linear_interpolation=linear_interpolation,
+            highest_mode=highest_mode
         )
-        ifos_homs = deepcopy(ifos)
-        ifos_homs.inject_signal(
-            parameters=self.test_parameters, waveform_generator=non_mb_wfg_homs
+        likelihood.parameters.update(self.test_parameters)
+        likelihood_mb.parameters.update(self.test_parameters)
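+        # optionally include the random calibration-spline perturbations created in setUp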
+        if add_cal_errors:
+            likelihood.parameters.update(self.calibration_parameters)
+            likelihood_mb.parameters.update(self.calibration_parameters)
+        self.assertLess(
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()),
+            tolerance
         )
 
-        self.non_mb_22 = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=non_mb_wfg_22
-        )
-        self.non_mb_homs = bilby.gw.likelihood.GravitationalWaveTransient(
-            interferometers=ifos_homs, waveform_generator=non_mb_wfg_homs
+    def test_large_accuracy_factor(self):
+        """
+        Check that a larger accuracy factor improves the accuracy.
+        """
+        approximant = "IMRPhenomD"
+        wfg = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
+        self.ifos.inject_signal(parameters=self.test_parameters, waveform_generator=wfg)
 
-        self.mb_22 = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
-            reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy()
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant=approximant
+            )
         )
-        self.mb_ifftfft_22 = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
-            reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), linear_interpolation=False
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg
         )
-        self.mb_homs = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_homs, waveform_generator=deepcopy(mb_wfg_homs),
+        likelihood_mb = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
             reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), linear_interpolation=False, highest_mode=4
+            priors=self.priors.copy(), accuracy_factor=5
         )
-        self.mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient(
-            interferometers=ifos_22, waveform_generator=deepcopy(mb_wfg_22),
+        likelihood_mb_more_accurate = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
             reference_chirp_mass=self.test_parameters['chirp_mass'],
-            priors=priors.copy(), accuracy_factor=50
-        )
-
-    def tearDown(self):
-        del (
-            self.non_mb_22,
-            self.non_mb_homs,
-            self.mb_22,
-            self.mb_ifftfft_22,
-            self.mb_homs,
-            self.mb_more_accurate
+            priors=self.priors.copy(), accuracy_factor=50
         )
-
-    @parameterized.expand([(False, ), (True, )])
-    def test_matches_non_mb(self, add_cal_errors):
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_22.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_22.parameters.update(self.calibration_parameters)
-            self.mb_22.parameters.update(self.calibration_parameters)
+        likelihood.parameters.update(self.test_parameters)
+        likelihood_mb.parameters.update(self.test_parameters)
+        likelihood_mb_more_accurate.parameters.update(self.test_parameters)
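+        # accuracy_factor=50 should at least halve the error obtained with accuracy_factor=5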
         self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()),
-            1.5e-2
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb_more_accurate.log_likelihood_ratio()),
+            abs(likelihood.log_likelihood_ratio() - likelihood_mb.log_likelihood_ratio()) / 2
         )
 
-    @parameterized.expand([(False, ), (True, )])
-    def test_ifft_fft(self, add_cal_errors):
+    def test_reference_chirp_mass_from_prior(self):
         """
-        Check if multi-banding likelihood with (h, h) computed with the
-        IFFT-FFT algorithm matches the original likelihood.
+        Check that the reference chirp mass is determined automatically from the prior when no value is passed.
         """
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_ifftfft_22.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_22.parameters.update(self.calibration_parameters)
-            self.mb_ifftfft_22.parameters.update(self.calibration_parameters)
-        self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_ifftfft_22.log_likelihood_ratio()),
-            6e-3
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
+        )
+        likelihood1 = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            reference_chirp_mass=self.priors["chirp_mass"].minimum,
+            priors=self.priors.copy()
+        )
+        likelihood2 = bilby.gw.likelihood.MBGravitationalWaveTransient(
+            interferometers=self.ifos, waveform_generator=wfg_mb,
+            priors=self.priors.copy()
         )
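+        # the automatically determined value should match the explicitly supplied prior minimum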
+        self.assertAlmostEqual(likelihood1.reference_chirp_mass, likelihood2.reference_chirp_mass)
 
-    @parameterized.expand([(False, ), (True, )])
-    def test_homs(self, add_cal_errors):
+    def test_no_reference_chirp_mass(self):
         """
-        Check if multi-banding likelihood matches the original likelihood for higher-order moments.
+        Check that an error is raised when neither reference_chirp_mass nor priors is specified.
         """
-        self.non_mb_homs.parameters.update(self.test_parameters)
-        self.mb_homs.parameters.update(self.test_parameters)
-        if add_cal_errors:
-            self.non_mb_homs.parameters.update(self.calibration_parameters)
-            self.mb_homs.parameters.update(self.calibration_parameters)
-        self.assertLess(
-            abs(self.non_mb_homs.log_likelihood_ratio() - self.mb_homs.log_likelihood_ratio()),
-            1e-3
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
         )
+        with self.assertRaises(TypeError):
+            bilby.gw.likelihood.MBGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=wfg_mb
+            )
 
-    def test_large_accuracy_factor(self):
+    def test_cannot_determine_reference_chirp_mass(self):
         """
-        Check if larger accuracy factor increases the accuracy.
+        Check that an error is raised when the priors lack the information needed to determine the reference chirp mass.
         """
-        self.non_mb_22.parameters.update(self.test_parameters)
-        self.mb_22.parameters.update(self.test_parameters)
-        self.mb_more_accurate.parameters.update(self.test_parameters)
-        self.assertLess(
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_more_accurate.log_likelihood_ratio()),
-            abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()) / 2
+        wfg_mb = bilby.gw.WaveformGenerator(
+            duration=self.duration, sampling_frequency=self.sampling_frequency,
+            frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
+            waveform_arguments=dict(
+                reference_frequency=self.fmin, approximant="IMRPhenomD"
+            )
         )
+        for key in ["chirp_mass", "mass_1", "mass_2"]:
+            if key in self.priors:
+                self.priors.pop(key)
+        with self.assertRaises(Exception):
+            bilby.gw.likelihood.MBGravitationalWaveTransient(
+                interferometers=self.ifos, waveform_generator=wfg_mb, priors=self.priors
+            )
 
 
 if __name__ == "__main__":
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index cdf549ccd519737916e97422f00b5de92feaf0ea..f9304971da4b5d91bc3475a1b3280576f96ef525 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -41,9 +41,8 @@ _sampler_kwargs = dict(
     kombine=dict(iterations=200, nwalkers=10, autoburnin=False),
     nessai=dict(
         nlive=100,
-        poolsize=1000,
-        max_iteration=1000,
-        max_threads=3,
+        poolsize=100,
+        max_iteration=500,
     ),
     nestle=dict(nlive=100),
     ptemcee=dict(
@@ -159,11 +158,6 @@ class TestRunningSamplers(unittest.TestCase):
             pytest.skip(f"{sampler} cannot be parallelized")
-        if sys.version_info.minor == 8 and sampler.lower == "cpnest":
+        if sys.version_info.minor == 8 and sampler.lower() == "cpnest":
             pytest.skip("Pool interrupting broken for cpnest with py3.8")
-        if sampler.lower() == "nessai" and pool_size > 1:
-            pytest.skip(
-                "Interrupting with a pool is failing in pytest. "
-                "Likely due to interactions with the signal handling in nessai."
-            )
         pid = os.getpid()
         print(sampler)