diff --git a/.coveragerc b/.coveragerc
index c051155b5f2b882364426a8f85bb5633b7afcd04..7d64384d997739ae73dc950ed2fd722d17f73322 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -3,3 +3,4 @@ omit =
     test/integration/example_test.py
     test/integration/noise_realisation_test.py
     test/integration/other_test.py
+    bilby/_version.py
diff --git a/.gitignore b/.gitignore
index 88717818866b233ab49b644f24d116aa61efea91..48e4c01a1387a87c5af4a7381f1f5d157171e3c0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,4 @@ MANIFEST
 *.ipynb_checkpoints
 outdir/*
 .idea/*
+bilby/_version.py
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 7663fbac2e98490a65363d7fb217fa678f16ec65..5fa17f7bb4da803074c3672413c2e935dc84da6d 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -135,11 +135,12 @@ python-3.9:
   after_script:
     - coverage html
     - coverage xml
+  coverage: '/(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$/'
   artifacts:
     reports:
       coverage_report:
         coverage_format: cobertura
-        path: coverage/cobertura-coverage.xml
+        path: coverage.xml
     paths:
       - htmlcov/
     expire_in: 30 days
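The `coverage:` regex added above extracts the total percentage from the textual coverage report. A quick sanity check of the same pattern using Python's `re` (illustrative only; GitLab's matcher dialect may differ slightly):

import re

# Same pattern as the 'coverage:' key above, checked against a typical
# "coverage report" TOTAL line.
pattern = r"(?i)total.*? (100(?:\.0+)?\%|[1-9]?\d(?:\.\d+)?\%)$"
line = "TOTAL                      1000    120    88%"
match = re.search(pattern, line)
assert match is not None and match.group(1) == "88%"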
@@ -153,9 +154,10 @@ python-3.10:
   stage: test
   script:
     - python -m pip install .
+    - python -m pip install schwimmbad
     - python -m pip list installed
 
-    - pytest test/integration/sampler_run_test.py --durations 10
+    - pytest test/integration/sampler_run_test.py --durations 10 -v
 
 python-3.8-samplers:
   <<: *test-sampler
@@ -282,7 +284,7 @@ pypi-release:
     TWINE_USERNAME: $PYPI_USERNAME
     TWINE_PASSWORD: $PYPI_PASSWORD
   before_script:
-    - pip install twine
+    - pip install twine setuptools_scm
     - python setup.py sdist
   script:
     - twine upload dist/*
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 14215c6f183db9d846f38f82067ecd0f7878f1f4..a78a604a57217f0d579213ef06be28307664da50 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ repos:
     hooks:
       - id: black
         language_version: python3
-        files: '(^bilby/bilby_mcmc/|^examples/core_examples/|examples/gw_examples/data_examples)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
 -   repo: https://github.com/codespell-project/codespell
     rev: v2.1.0
     hooks:
@@ -20,7 +20,7 @@ repos:
     hooks:
     -   id: isort # sort imports alphabetically and separates import into sections
         args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
-        files: '(^bilby/bilby_mcmc/|^examples/core_examples/examples/gw_examples/data_examples)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
 -   repo: https://github.com/datarootsio/databooks
     rev: 0.1.14
     hooks:
diff --git a/AUTHORS.md b/AUTHORS.md
index 32f41f250997c53eaa77e2217a7a7326c3be4081..819c6e71d245f909f12f3c40fe8be87c76412893 100644
--- a/AUTHORS.md
+++ b/AUTHORS.md
@@ -24,6 +24,7 @@ Eric Thrane
 Ethan Payne
 Francisco Javier Hernandez
 Gregory Ashton
+Hank Hua
 Hector Estelles
 Ignacio Magaña Hernandez
 Isobel Marguarethe Romero-Shaw
@@ -34,6 +35,7 @@ Jeremy G Baier
 John Veitch
 Joshua Brandt
 Josh Willis
+Karl Wette
 Katerina Chatziioannou
 Kaylee de Soto
 Khun Sang Phukon
@@ -59,6 +61,7 @@ Paul Easter
 Paul Lasky
 Philip Relton
 Rhys Green
+Richard Udall
 Rico Lo
 Roberto Cotesta
 Rory Smith
@@ -75,6 +78,9 @@ Stephen R Green
 Sumeet Kulkarni
 Sylvia Biscoveanu
 Tathagata Ghosh
+Tomasz Baka
+Will M. Farr
 Virginia d'Emilio
 Vivien Raymond
 Ka-Lok Lo
+Isaac Legred
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 73cdbefb1a01c4f877ee657fa6f3c4fc668162e4..635c69e0a88621f116ee60674c61205586b3547c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,56 @@
 # All notable changes will be documented in this file
 
+## [1.2.1] 2022-09-05
+Version 1.2.1 release of Bilby
+
+This release contains a few bug fixes following 1.2.0.
+
+### Changes
+- Improve how sampling seed is handled across samplers (!1134)
+- Make sure labels are included when evidences are in corner plot legend (!1135)
+- Remove calls to `getargspec` (!1136)
+- Make sure parameter reconstruction cache is not mangled when reading (!1126)
+- Enable the constant uncertainty calibration spline to have a specifiable boundary condition (!1137)
+- Fix a bug in checkpointing for `bilby_mcmc` (!1141)
+- Fix the `LALCBCWaveformGenerator` (!1140)
+- Switch to automatic versioning with `setuptools_scm` (!1125)
+- Improve the stability of the multivariate normal prior (!1142)
+- Extend mass conversions to include source-frame parameters (!1131)
+- Fix prior ranges for GW150914 example (!1129)
+
+## [1.2.0] 2022-08-15
+Version 1.2.0 release of Bilby
+
+This is the first release that drops support for `Python<3.8`.
+
+This release involves major refactoring, especially of the sampler implementations.
+
+Additionally, there is a range of improvements to how information is passed
+between processes when using multiprocessing.
+
+### Added
+- Time marginalized ROQ likelihood (!1040)
+- Multiple and multi-banded ROQ likelihood (!1093)
+- Gaussian process likelihoods (!1086)
+- `CBCWaveformGenerator` added with CBC specific defaults (!1080)
+
+### Changes
+- Fixes and improvements to multi-processing (!1084, !1043, !1096)
+- Major refactoring of sampler implementations (!1043)
+- Fixes for reading/writing priors (!1103, !1127, !1128)
+- Fixes/updates to example scripts (!1050, !1031, !1076, !1081, !1074)
+- Fixes to calibration correction in GW likelihoods (!1114, !1120, !1119)
+
+### Deprecated/removed
+- Require `Python>=3.8`
+- Require `astropy>=5`
+- `bilby.core.utils.conversion.gps_time_to_gmst`
+- `bilby.core.utils.spherical_to_cartesian`
+- `bilby.core.utils.progress`
+- Deepdish IO for `Result`, `Interferometer`, and `InterferometerList`
+
 ## [1.1.5] 2022-01-14
-Version 1.1.5 release of bilby
+Version 1.1.5 release of Bilby
 
 ### Added
 - Option to enforce that a GW signal fits into the segment duration (!1041)
diff --git a/MANIFEST.in b/MANIFEST.in
index d6ac4a64451d950e361c2c2c860632e80f8d4791..331a31917140aef951df8fcfb55384beab56fcc6 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -5,4 +5,5 @@ include gw_requirements.txt
 include mcmc_requirements.txt
 include optional_requirements.txt
 include sampler_requirements.txt
+include bilby/_version.py
 recursive-include test *.py *.prior
diff --git a/bilby/__init__.py b/bilby/__init__.py
index 04a29e1bc27008ffed916a8f8f171bc6a2315c7c..092f05a1a52f146b1375ccb39869c14ae6763149 100644
--- a/bilby/__init__.py
+++ b/bilby/__init__.py
@@ -24,7 +24,10 @@ from .core import utils, likelihood, prior, result, sampler
 from .core.sampler import run_sampler
 from .core.likelihood import Likelihood
 
-__version__ = utils.get_version_information()
+try:
+    from ._version import version as __version__
+except ModuleNotFoundError:  # development mode
+    __version__ = 'unknown'
 
 
 if sys.version_info < (3,):
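For reference, the `bilby/_version.py` entries added to `.coveragerc`, `.gitignore`, and `MANIFEST.in` exist because `setuptools_scm` writes that file at build time; the try/except import above then reads it at runtime. A minimal sketch of the assumed configuration (the MR's actual setup may differ):

# setup.py (illustrative sketch, not copied from the MR): setuptools_scm
# derives the version from git tags and writes it to bilby/_version.py.
from setuptools import setup

setup(
    name="bilby",
    use_scm_version={"write_to": "bilby/_version.py"},
    setup_requires=["setuptools_scm"],
)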
diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 9f7925bf340eafcb07d8e9903e3c612b1a03ee4f..6485ce1a00bfddb52e0d5e11d608ba20044a6ef8 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -1,6 +1,5 @@
 import datetime
 import os
-import signal
 import time
 from collections import Counter
 
@@ -8,7 +7,13 @@ import numpy as np
 import pandas as pd
 
 from ..core.result import rejection_sample
-from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError
+from ..core.sampler.base_sampler import (
+    MCMCSampler,
+    ResumeError,
+    SamplerError,
+    _sampling_convenience_dump,
+    signal_wrapper,
+)
 from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
 from . import proposals
 from .chain import Chain, Sample
@@ -131,7 +136,6 @@ class Bilby_MCMC(MCMCSampler):
         autocorr_c=5,
         L1steps=100,
         L2steps=3,
-        npool=1,
         printdt=60,
         min_tau=1,
         proposal_cycle="default",
@@ -154,7 +158,7 @@ class Bilby_MCMC(MCMCSampler):
         diagnostic=False,
         resume=True,
         exit_code=130,
-        verbose=False,
+        verbose=True,
         **kwargs,
     ):
 
@@ -172,7 +176,6 @@ class Bilby_MCMC(MCMCSampler):
         self.check_point_plot = check_point_plot
         self.diagnostic = diagnostic
         self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
-        self.npool = self.kwargs["npool"]
         self.L1steps = self.kwargs["L1steps"]
         self.L2steps = self.kwargs["L2steps"]
         self.pt_inputs = ParallelTemperingInputs(
@@ -194,22 +197,12 @@ class Bilby_MCMC(MCMCSampler):
         self.verify_configuration()
         self.verbose = verbose
 
-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern."
-            )
-
     def verify_configuration(self):
         if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
             logger.warning("Burn-in inefficiency fraction greater than 10%")
 
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "printdt" not in kwargs:
             for equiv in ["print_dt", "print_update"]:
                 if equiv in kwargs:
@@ -223,6 +216,7 @@ class Bilby_MCMC(MCMCSampler):
     def target_nsamples(self):
         return self.kwargs["target_nsamples"]
 
+    @signal_wrapper
     def run_sampler(self):
         self._setup_pool()
         self.setup_chain_set()
@@ -377,31 +371,12 @@ class Bilby_MCMC(MCMCSampler):
             f"setup:\n{self.get_setup_string()}"
         )
 
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.npool == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
-
     def write_current_state(self):
         import dill
 
+        if not hasattr(self, "ptsampler"):
+            logger.debug("Attempted checkpoint before initialization")
+            return
         logger.debug("Check point")
         check_directory_exists_and_if_not_mkdir(self.outdir)
 
@@ -534,39 +509,6 @@ class Bilby_MCMC(MCMCSampler):
                         all_samples=ptsampler.samples,
                     )
 
-    def _setup_pool(self):
-        if self.npool > 1:
-            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
-            import multiprocessing
-
-            self.pool = multiprocessing.Pool(
-                processes=self.npool,
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio,
-                ),
-            )
-        else:
-            self.pool = None
-
-        _initialize_global_variables(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            search_parameter_keys=self._search_parameter_keys,
-            use_ratio=self.use_ratio,
-        )
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            logger.info("Finished closing worker pool.")
-
 
 class BilbyPTMCMCSampler(object):
     def __init__(
@@ -579,7 +521,6 @@ class BilbyPTMCMCSampler(object):
         use_ratio,
         evidence_method,
     ):
-
         self.set_pt_inputs(pt_inputs)
         self.use_ratio = use_ratio
         self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
@@ -597,7 +538,7 @@ class BilbyPTMCMCSampler(object):
 
         self._nsamples_dict = {}
         self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
-            _priors
+            _sampling_convenience_dump.priors
         )
         self.sampling_time = 0
         self.ln_z_dict = dict()
@@ -612,7 +553,7 @@ class BilbyPTMCMCSampler(object):
         elif pt_inputs.Tmax is not None:
             betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
         elif pt_inputs.Tmax_from_SNR is not None:
-            ndim = len(_priors.non_fixed_keys)
+            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
             target_hot_likelihood = ndim / 2
             Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
             betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
@@ -1140,12 +1081,14 @@ class BilbyMCMCSampler(object):
         self.Eindex = Eindex
         self.use_ratio = use_ratio
 
-        self.parameters = _priors.non_fixed_keys
+        self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
         self.ndim = len(self.parameters)
 
-        full_sample_dict = _priors.sample()
+        full_sample_dict = _sampling_convenience_dump.priors.sample()
         initial_sample = {
-            k: v for k, v in full_sample_dict.items() if k in _priors.non_fixed_keys
+            k: v
+            for k, v in full_sample_dict.items()
+            if k in _sampling_convenience_dump.priors.non_fixed_keys
         }
         initial_sample = Sample(initial_sample)
         initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
@@ -1168,7 +1111,10 @@ class BilbyMCMCSampler(object):
                 warn = False
 
             self.proposal_cycle = proposals.get_proposal_cycle(
-                proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn
+                proposal_cycle,
+                _sampling_convenience_dump.priors,
+                L1steps=self.chain.L1steps,
+                warn=warn,
             )
         elif isinstance(proposal_cycle, proposals.ProposalCycle):
             self.proposal_cycle = proposal_cycle
@@ -1185,17 +1131,17 @@ class BilbyMCMCSampler(object):
         self.stop_after_convergence = convergence_inputs.stop_after_convergence
 
     def log_likelihood(self, sample):
-        _likelihood.parameters.update(sample.sample_dict)
+        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)
 
         if self.use_ratio:
-            logl = _likelihood.log_likelihood_ratio()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            logl = _likelihood.log_likelihood()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood()
 
         return logl
 
     def log_prior(self, sample):
-        return _priors.ln_prob(sample.parameter_only_dict)
+        return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict)
 
     def accept_proposal(self, prop, proposal):
         self.chain.append(prop)
@@ -1293,8 +1239,10 @@ class BilbyMCMCSampler(object):
         zerotemp_logl = hot_samples[LOGLKEY]
 
         # Revert to true likelihood if needed
-        if _use_ratio:
-            zerotemp_logl += _likelihood.noise_log_likelihood()
+        if _sampling_convenience_dump.use_ratio:
+            zerotemp_logl += (
+                _sampling_convenience_dump.likelihood.noise_log_likelihood()
+            )
 
         # Calculate normalised weights
         log_weights = (1 - beta) * zerotemp_logl
@@ -1322,29 +1270,3 @@ class BilbyMCMCSampler(object):
 def call_step(sampler):
     sampler = sampler.step()
     return sampler
-
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-    likelihood,
-    priors,
-    search_parameter_keys,
-    use_ratio,
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
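The module-level globals removed above are replaced by the `_sampling_convenience_dump` container in `base_sampler.py` (see below), populated in each worker via a pool initializer. A toy sketch of that pattern, with hypothetical names:

import multiprocessing

_worker_state = {}


def _init_worker(likelihood):
    # Runs once in each worker process when the pool starts up, so the
    # "likelihood" is not re-pickled for every task sent to the pool.
    _worker_state["likelihood"] = likelihood


def _evaluate(x):
    # Reads the per-process copy installed by the initializer
    return _worker_state["likelihood"](x)


if __name__ == "__main__":
    # abs stands in for a likelihood callable in this toy example
    with multiprocessing.Pool(
        processes=2, initializer=_init_worker, initargs=(abs,)
    ) as pool:
        print(pool.map(_evaluate, [-1.0, 2.0]))  # -> [1.0, 2.0]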
diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py
index ea8607545c59126ab7e5c4d6e83b8c5ba7ba83d8..e11f7862ad7360737760eaa272cd6157f85d8ba6 100644
--- a/bilby/core/prior/base.py
+++ b/bilby/core/prior/base.py
@@ -214,10 +214,13 @@ class Prior(object):
 
         """
         prior_name = self.__class__.__name__
+        prior_module = self.__class__.__module__
         instantiation_dict = self.get_instantiation_dict()
-        args = ', '.join(['{}={}'.format(key, repr(instantiation_dict[key]))
-                          for key in instantiation_dict])
-        return "{}({})".format(prior_name, args)
+        args = ', '.join([f'{key}={repr(instantiation_dict[key])}' for key in instantiation_dict])
+        if "bilby.core.prior" in prior_module:
+            return f"{prior_name}({args})"
+        else:
+            return f"{prior_module}.{prior_name}({args})"
 
     @property
     def _repr_dict(self):
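A short sketch of what the repr change achieves: priors are round-tripped through their repr when written to and read from prior files, so classes defined outside `bilby.core.prior` now carry their module path. `MyUniform` below is a hypothetical subclass for illustration:

from bilby.core.prior import Uniform


class MyUniform(Uniform):
    pass


# Core priors keep the short form ...
print(repr(Uniform(0, 1, "x")))    # Uniform(minimum=0, maximum=1, ...)
# ... while external classes include their module (here __main__), so they
# can be re-imported when the prior file is loaded.
print(repr(MyUniform(0, 1, "y")))  # __main__.MyUniform(minimum=0, ...)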
diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py
index 107822828c52ed32111a9462c1d8f4325b143719..c4dcc36827fd844853095dcbbcafd69f2585c40a 100644
--- a/bilby/core/prior/conditional.py
+++ b/bilby/core/prior/conditional.py
@@ -371,7 +371,7 @@ class DirichletElement(ConditionalBeta):
         self._required_variables = [
             label + str(ii) for ii in range(order)
         ]
-        self.__class__.__name__ = 'Dirichlet'
+        self.__class__.__name__ = 'DirichletElement'
 
     def dirichlet_condition(self, reference_parms, **kwargs):
         remaining = 1 - sum(
diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index 85e8ac4ef75c8808fd20f1b175fd5214364deb1a..d25ca6487560bbb946b8930ba8c625e9d0f2d766 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -824,6 +824,8 @@ class ConditionalPriorDict(PriorDict):
         =======
         list: List of floats containing the rescaled sample
         """
+        from matplotlib.cbook import flatten
+
         keys = list(keys)
         theta = list(theta)
         self._check_resolved()
@@ -836,7 +838,7 @@ class ConditionalPriorDict(PriorDict):
                 theta[index], **self.get_required_variables(key)
             )
             self[key].least_recently_sampled = result[key]
-        return [result[key] for key in keys]
+        return list(flatten([result[key] for key in keys]))
 
     def _update_rescale_keys(self, keys):
         if not keys == self._least_recently_rescaled_keys:
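The new `flatten` call matters because a joint (multivariate) prior's rescale returns a sequence per key, so the comprehension above can yield a nested list; flattening restores the flat list of floats the samplers expect. A minimal illustration:

from matplotlib.cbook import flatten

# e.g. two scalar priors plus one 2D joint prior
nested = [0.1, [0.2, 0.3], 0.4]
assert list(flatten(nested)) == [0.1, 0.2, 0.3, 0.4]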
diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py
index 742e7920b067877ce4d8aa3d2aa0f862db1dee4f..7f35b8716526e8811571409c2aba6d476a9958cc 100644
--- a/bilby/core/prior/joint.py
+++ b/bilby/core/prior/joint.py
@@ -43,7 +43,7 @@ class BaseJointPriorDist(object):
                 if isinstance(bounds, (list, tuple, np.ndarray)):
                     if len(bound) != 2:
                         raise ValueError(
-                            "Bounds must contain an upper and " "lower value."
+                            "Bounds must contain an upper and lower value."
                         )
                     else:
                         if bound[1] <= bound[0]:
@@ -382,6 +382,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
         self.covs = []
         self.corrcoefs = []
         self.sigmas = []
+        self.logprodsigmas = []  # log of product of sigmas, needed for "standard" multivariate normal
         self.weights = []
         self.eigvalues = []
         self.eigvectors = []
@@ -399,7 +400,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                 if len(np.shape(sigmas)) == 1:
                     sigmas = [sigmas]
                 elif len(np.shape(sigmas)) == 0:
-                    raise ValueError("Must supply a list of standard " "deviations")
+                    raise ValueError("Must supply a list of standard deviations")
             if covs is not None:
                 if isinstance(covs, np.ndarray):
                     covs = [covs]
@@ -421,7 +422,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                             "List of correlation coefficients the wrong shape"
                         )
                 elif not isinstance(corrcoefs, list):
-                    raise TypeError("Must pass a list of correlation " "coefficients")
+                    raise TypeError("Must pass a list of correlation coefficients")
             if weights is not None:
                 if isinstance(weights, (int, float)):
                     weights = [weights]
@@ -489,7 +490,7 @@ class MultivariateGaussianDist(BaseJointPriorDist):
 
             if len(self.corrcoefs[-1].shape) != 2:
                 raise ValueError(
-                    "Correlation coefficient matrix must be a 2d " "array."
+                    "Correlation coefficient matrix must be a 2d array."
                 )
 
             if (
@@ -497,16 +498,16 @@ class MultivariateGaussianDist(BaseJointPriorDist):
                 or self.corrcoefs[-1].shape[0] != self.num_vars
             ):
                 raise ValueError(
-                    "Correlation coefficient matrix shape is " "inconsistent"
+                    "Correlation coefficient matrix shape is inconsistent"
                 )
 
             # check matrix is symmetric
             if not np.allclose(self.corrcoefs[-1], self.corrcoefs[-1].T):
-                raise ValueError("Correlation coefficient matrix is not " "symmetric")
+                raise ValueError("Correlation coefficient matrix is not symmetric")
 
             # check diagonal is all ones
             if not np.all(np.diag(self.corrcoefs[-1]) == 1.0):
-                raise ValueError("Correlation coefficient matrix is not" "correct")
+                raise ValueError("Correlation coefficient matrix is not correct")
 
             try:
                 self.sigmas.append(list(sigmas))  # standard deviations
@@ -528,6 +529,9 @@ class MultivariateGaussianDist(BaseJointPriorDist):
             self.covs.append(np.eye(self.num_vars))
             self.sigmas.append(np.ones(self.num_vars))
 
+        # compute log of product of sigmas, needed for "standard" multivariate normal
+        self.logprodsigmas.append(np.log(np.prod(self.sigmas[-1])))
+
         # get eigen values and vectors
         try:
             evals, evecs = np.linalg.eig(self.corrcoefs[-1])
@@ -535,13 +539,13 @@ class MultivariateGaussianDist(BaseJointPriorDist):
             self.eigvectors.append(evecs)
         except Exception as e:
             raise RuntimeError(
-                "Problem getting eigenvalues and vectors: " "{}".format(e)
+                "Problem getting eigenvalues and vectors: {}".format(e)
             )
 
         # check eigenvalues are positive
         if np.any(self.eigvalues[-1] <= 0.0):
             raise ValueError(
-                "Correlation coefficient matrix is not positive " "definite"
+                "Correlation coefficient matrix is not positive definite"
             )
         self.sqeigvalues.append(np.sqrt(self.eigvalues[-1]))
 
@@ -557,9 +561,16 @@ class MultivariateGaussianDist(BaseJointPriorDist):
         # add the mode
         self.nmodes += 1
 
-        # add multivariate Gaussian
+        # add "standard" multivariate normal distribution
+        # - when the typical scales of the parameters are very different,
+        #   multivariate_normal() may complain that the covariance matrix is singular
+        # - instead, pass zero means and the correlation matrix (rather than
+        #   the covariance matrix) to get the equivalent of a standard normal
+        #   distribution in higher dimensions
+        # - this modifies the multivariate normal PDF as follows:
+        #     multivariate_normal(mean=mus, cov=cov).logpdf(x)
+        #     = multivariate_normal(mean=0, cov=corrcoefs).logpdf((x - mus)/sigmas) - logprodsigmas
         self.mvn.append(
-            scipy.stats.multivariate_normal(mean=self.mus[-1], cov=self.covs[-1])
+            scipy.stats.multivariate_normal(mean=np.zeros(self.num_vars), cov=self.corrcoefs[-1])
         )
 
     def _rescale(self, samp, **kwargs):
@@ -630,7 +641,9 @@ class MultivariateGaussianDist(BaseJointPriorDist):
         for j in range(samp.shape[0]):
             # loop over the modes and sum the probabilities
             for i in range(self.nmodes):
-                lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(samp[j]))
+                # self.mvn[i] is a "standard" multivariate normal distribution; see add_mode()
+                z = (samp[j] - self.mus[i]) / self.sigmas[i]
+                lnprob[j] = np.logaddexp(lnprob[j], self.mvn[i].logpdf(z) - self.logprodsigmas[i])
 
         # set out-of-bounds values to -inf
         lnprob[outbounds] = -np.inf
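A quick numerical check of the whitening identity stated in the `add_mode` comment above (illustrative only; with still more disparate sigmas, the direct construction can fail with a singular-matrix error, which is what this change avoids):

import numpy as np
from scipy.stats import multivariate_normal

mus = np.array([1.0, -2.0])
sigmas = np.array([1e-2, 1e2])  # very different scales
corrcoefs = np.array([[1.0, 0.5], [0.5, 1.0]])
cov = np.outer(sigmas, sigmas) * corrcoefs  # diag(sigmas) @ C @ diag(sigmas)

x = np.array([1.02, 150.0])
direct = multivariate_normal(mean=mus, cov=cov).logpdf(x)
standard = multivariate_normal(
    mean=np.zeros(2), cov=corrcoefs
).logpdf((x - mus) / sigmas) - np.log(np.prod(sigmas))
assert np.isclose(direct, standard)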
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 221420c90cc375a673ffb327a4a58be341b2c2b2..bef3f11af4c6ac632b31d23e693668340a29a582 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -298,7 +298,7 @@ def reweight(result, label=None, new_likelihood=None, new_prior=None,
 
     if conversion_function is not None:
         data_frame = result.posterior
-        if "npool" in inspect.getargspec(conversion_function).args:
+        if "npool" in inspect.signature(conversion_function).parameters:
             data_frame = conversion_function(data_frame, new_likelihood, new_prior, npool=npool)
         else:
             data_frame = conversion_function(data_frame, new_likelihood, new_prior)
@@ -1389,7 +1389,7 @@ class Result(object):
             data_frame['log_prior'] = self.log_prior_evaluations
 
         if conversion_function is not None:
-            if "npool" in inspect.getargspec(conversion_function).args:
+            if "npool" in inspect.signature(conversion_function).parameters:
                 data_frame = conversion_function(data_frame, likelihood, priors, npool=npool)
             else:
                 data_frame = conversion_function(data_frame, likelihood, priors)
@@ -1435,8 +1435,11 @@ class Result(object):
         if keys is None:
             keys = self.search_parameter_keys
         if self.injection_parameters is None:
-            raise(TypeError, "Result object has no 'injection_parameters'. "
-                             "Cannot compute credible levels.")
+            raise TypeError(
+                "Result object has no 'injection_parameters'. "
+                "Cannot compute credible levels."
+            )
         credible_levels = {key: self.get_injection_credible_level(key, weights=weights)
                            for key in keys
                            if isinstance(self.injection_parameters.get(key, None), float)}
@@ -1462,8 +1465,11 @@ class Result(object):
         float: credible level
         """
         if self.injection_parameters is None:
-            raise(TypeError, "Result object has no 'injection_parameters'. "
-                             "Cannot copmute credible levels.")
+            raise TypeError(
+                "Result object has no 'injection_parameters'. "
+                "Cannot compute credible levels."
+            )
 
         if weights is None:
             weights = np.ones(len(self.posterior))
@@ -1933,12 +1939,17 @@ def plot_multiple(results, filename=None, labels=None, colours=None,
 
     if evidences:
         if np.isnan(results[0].log_bayes_factor):
-            template = r' $\mathrm{{ln}}(Z)={lnz:1.3g}$'
+            template = r'{label} $\mathrm{{ln}}(Z)={lnz:1.3g}$'
         else:
-            template = r' $\mathrm{{ln}}(B)={lnbf:1.3g}$'
-        labels = [template.format(lnz=result.log_evidence,
-                                  lnbf=result.log_bayes_factor)
-                  for ii, result in enumerate(results)]
+            template = r'{label} $\mathrm{{ln}}(B)={lnbf:1.3g}$'
+        labels = [
+            template.format(
+                label=label,
+                lnz=result.log_evidence,
+                lnbf=result.log_bayes_factor,
+            )
+            for label, result in zip(labels, results)
+        ]
 
     axes = fig.get_axes()
     ndim = int(np.sqrt(len(axes)))
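`inspect.getargspec` has been deprecated since Python 3.0 and was removed in 3.11; `inspect.signature` is the supported replacement and, unlike `getargspec`, also handles keyword-only arguments. A minimal illustration of the check used above:

import inspect


def conversion_function(samples, likelihood, priors, *, npool=1):
    # toy stand-in for a user-supplied conversion function
    return samples


assert "npool" in inspect.signature(conversion_function).parameters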
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index bd259f97a74c4784ecf4cf8b5fe2c742e7bc79b0..ce9422f85e6649c74c57b16997678a681b247841 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -1,29 +1,30 @@
+import datetime
 import inspect
 import sys
-import datetime
 
 import bilby
-from ..utils import command_line_args, logger, loaded_modules_dict
-from ..prior import PriorDict, DeltaFunction
+from bilby.bilby_mcmc import Bilby_MCMC
+
+from ..prior import DeltaFunction, PriorDict
+from ..utils import command_line_args, loaded_modules_dict, logger
+from . import proposal
 from .base_sampler import Sampler, SamplingMarginalisedParameterError
 from .cpnest import Cpnest
+from .dnest4 import DNest4
 from .dynamic_dynesty import DynamicDynesty
 from .dynesty import Dynesty
 from .emcee import Emcee
+from .fake_sampler import FakeSampler
 from .kombine import Kombine
 from .nessai import Nessai
 from .nestle import Nestle
 from .polychord import PyPolyChord
 from .ptemcee import Ptemcee
 from .ptmcmc import PTMCMCSampler
-from .pymc3 import Pymc3
+from .pymc import Pymc
 from .pymultinest import Pymultinest
 from .ultranest import Ultranest
-from .fake_sampler import FakeSampler
-from .dnest4 import DNest4
 from .zeus import Zeus
-from bilby.bilby_mcmc import Bilby_MCMC
-from . import proposal
 
 IMPLEMENTED_SAMPLERS = {
     "bilby_mcmc": Bilby_MCMC,
@@ -37,7 +38,7 @@ IMPLEMENTED_SAMPLERS = {
     "nestle": Nestle,
     "ptemcee": Ptemcee,
     "ptmcmcsampler": PTMCMCSampler,
-    "pymc3": Pymc3,
+    "pymc": Pymc,
     "pymultinest": Pymultinest,
     "pypolychord": PyPolyChord,
     "ultranest": Ultranest,
@@ -49,7 +50,7 @@ if command_line_args.sampler_help:
     sampler = command_line_args.sampler_help
     if sampler in IMPLEMENTED_SAMPLERS:
         sampler_class = IMPLEMENTED_SAMPLERS[sampler]
-        print('Help for sampler "{}":'.format(sampler))
+        print(f'Help for sampler "{sampler}":')
         print(sampler_class.__doc__)
     else:
         if sampler == "None":
@@ -58,8 +59,8 @@ if command_line_args.sampler_help:
                 "the name of the sampler"
             )
         else:
-            print("Requested sampler {} not implemented".format(sampler))
-        print("Available samplers = {}".format(IMPLEMENTED_SAMPLERS))
+            print(f"Requested sampler {sampler} not implemented")
+        print(f"Available samplers = {IMPLEMENTED_SAMPLERS}")
 
     sys.exit()
 
@@ -81,7 +82,7 @@ def run_sampler(
     gzip=False,
     result_class=None,
     npool=1,
-    **kwargs
+    **kwargs,
 ):
     """
     The primary interface to easy parameter estimation
@@ -144,9 +145,7 @@ def run_sampler(
         An object containing the results
     """
 
-    logger.info(
-        "Running for label '{}', output will be saved to '{}'".format(label, outdir)
-    )
+    logger.info(f"Running for label '{label}', output will be saved to '{outdir}'")
 
     if clean:
         command_line_args.clean = clean
@@ -174,7 +173,7 @@ def run_sampler(
         meta_data = dict()
     likelihood.label = label
     likelihood.outdir = outdir
-    meta_data['likelihood'] = likelihood.meta_data
+    meta_data["likelihood"] = likelihood.meta_data
     meta_data["loaded_modules"] = loaded_modules_dict()
 
     if command_line_args.bilby_zero_likelihood_mode:
@@ -198,11 +197,11 @@ def run_sampler(
                 plot=plot,
                 result_class=result_class,
                 npool=npool,
-                **kwargs
+                **kwargs,
             )
         else:
             print(IMPLEMENTED_SAMPLERS)
-            raise ValueError("Sampler {} not yet implemented".format(sampler))
+            raise ValueError(f"Sampler {sampler} not yet implemented")
     elif inspect.isclass(sampler):
         sampler = sampler.__init__(
             likelihood,
@@ -214,12 +213,12 @@ def run_sampler(
             injection_parameters=injection_parameters,
             meta_data=meta_data,
             npool=npool,
-            **kwargs
+            **kwargs,
         )
     else:
         raise ValueError(
             "Provided sampler should be a Sampler object or name of a known "
-            "sampler: {}.".format(", ".join(IMPLEMENTED_SAMPLERS.keys()))
+            f"sampler: {', '.join(IMPLEMENTED_SAMPLERS.keys())}."
         )
 
     if sampler.cached_result:
@@ -240,23 +239,22 @@ def run_sampler(
         elif isinstance(result.sampling_time, (float, int)):
             result.sampling_time = datetime.timedelta(result.sampling_time)
 
-        logger.info('Sampling time: {}'.format(result.sampling_time))
+        logger.info(f"Sampling time: {result.sampling_time}")
         # Convert sampling time into seconds
         result.sampling_time = result.sampling_time.total_seconds()
 
         if sampler.use_ratio:
             result.log_noise_evidence = likelihood.noise_log_likelihood()
             result.log_bayes_factor = result.log_evidence
-            result.log_evidence = \
-                result.log_bayes_factor + result.log_noise_evidence
+            result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
         else:
             result.log_noise_evidence = likelihood.noise_log_likelihood()
-            result.log_bayes_factor = \
-                result.log_evidence - result.log_noise_evidence
+            result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
 
         if None not in [result.injection_parameters, conversion_function]:
             result.injection_parameters = conversion_function(
-                result.injection_parameters)
+                result.injection_parameters
+            )
 
         # Initial save of the sampler in case of failure in samples_to_posterior
         if save:
@@ -267,9 +265,12 @@ def run_sampler(
 
     # Check if the posterior has already been created
     if getattr(result, "_posterior", None) is None:
-        result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
-                                    conversion_function=conversion_function,
-                                    npool=npool)
+        result.samples_to_posterior(
+            likelihood=likelihood,
+            priors=result.priors,
+            conversion_function=conversion_function,
+            npool=npool,
+        )
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
@@ -277,7 +278,7 @@ def run_sampler(
 
     if plot:
         result.plot_corner()
-    logger.info("Summary of results:\n{}".format(result))
+    logger.info(f"Summary of results:\n{result}")
     return result
 
 
@@ -286,7 +287,5 @@ def _check_marginalized_parameters_not_sampled(likelihood, priors):
         if key in priors:
             if not isinstance(priors[key], (float, DeltaFunction)):
                 raise SamplingMarginalisedParameterError(
-                    "Likelihood is {} marginalized but you are trying to sample in {}. ".format(
-                        key, key
-                    )
+                    f"Likelihood is {key} marginalized but you are trying to sample in {key}. "
                 )
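A minimal end-to-end sketch of the `run_sampler` interface shown above, using a toy Gaussian likelihood (illustrative only; assumes dynesty is installed):

import bilby


class ToyLikelihood(bilby.Likelihood):
    def __init__(self):
        super().__init__(parameters={"x": None})

    def log_likelihood(self):
        # standard normal log-likelihood up to a constant
        return -0.5 * self.parameters["x"] ** 2


priors = dict(x=bilby.core.prior.Uniform(-5, 5, "x"))
result = bilby.run_sampler(
    likelihood=ToyLikelihood(),
    priors=priors,
    sampler="dynesty",
    nlive=250,
    npool=1,
    outdir="outdir",
    label="toy",
)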
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 215104a98087bb91abc3964684edf1a0a5d0d458..380ca001267eca8f09e7a9d02c0a6520e11850b3 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -1,18 +1,111 @@
 import datetime
 import distutils.dir_util
-import numpy as np
 import os
+import shutil
+import signal
+import sys
 import tempfile
+import time
 
+import attr
+import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir, command_line_args, Counter
-from ..prior import Prior, PriorDict, DeltaFunction, Constraint
+from ..prior import Constraint, DeltaFunction, Prior, PriorDict
 from ..result import Result, read_in_result
+from ..utils import (
+    Counter,
+    check_directory_exists_and_if_not_mkdir,
+    command_line_args,
+    logger,
+)
+
+
+@attr.s
+class _SamplingContainer:
+    """
+    A container class for objects that are stored independently in each
+    process for some samplers.
+
+    A single instance of this appears in this module and can be accessed
+    by the individual samplers.
+
+    This includes the:
+
+    - likelihood (bilby.core.likelihood.Likelihood)
+    - priors (bilby.core.prior.PriorDict)
+    - search_parameter_keys (list)
+    - use_ratio (bool)
+    """
+
+    likelihood = attr.ib(default=None)
+    priors = attr.ib(default=None)
+    search_parameter_keys = attr.ib(default=None)
+    use_ratio = attr.ib(default=False)
+
+
+_sampling_convenience_dump = _SamplingContainer()
+
+
+def _initialize_global_variables(
+    likelihood,
+    priors,
+    search_parameter_keys,
+    use_ratio,
+):
+    """
+    Store a global copy of the likelihood, priors, and search keys for
+    multiprocessing.
+    """
+    global _sampling_convenience_dump
+    _sampling_convenience_dump.likelihood = likelihood
+    _sampling_convenience_dump.priors = priors
+    _sampling_convenience_dump.search_parameter_keys = search_parameter_keys
+    _sampling_convenience_dump.use_ratio = use_ratio
+
+
+def signal_wrapper(method):
+    """
+    Decorator to wrap a method of a class to set system signals before running
+    and reset them after.
+
+    Parameters
+    ==========
+    method: callable
+        The method to call, this assumes the first argument is `self`
+        and that `self` has a `write_current_state_and_exit` method.
+
+    Returns
+    =======
+    output: callable
+        The wrapped method.
+    """
+
+    def wrapped(self, *args, **kwargs):
+        try:
+            old_term = signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
+            old_int = signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+            old_alarm = signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
+            _set = True
+        except (AttributeError, ValueError):
+            _set = False
+            logger.debug(
+                "Setting signal attributes unavailable on this system. "
+                "This is likely the case if you are running on a Windows machine "
+                "and can be safely ignored."
+            )
+        output = method(self, *args, **kwargs)
+        if _set:
+            signal.signal(signal.SIGTERM, old_term)
+            signal.signal(signal.SIGINT, old_int)
+            signal.signal(signal.SIGALRM, old_alarm)
+        return output
+
+    return wrapped
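A toy sketch of how the decorator is consumed (the real use is the `@signal_wrapper` on `Bilby_MCMC.run_sampler` above); `ToySampler` is hypothetical:

class ToySampler:
    def write_current_state_and_exit(self, signum=None, frame=None):
        print(f"caught signal {signum}; would checkpoint and exit here")

    @signal_wrapper
    def run(self):
        # SIGTERM/SIGINT/SIGALRM are routed to write_current_state_and_exit
        # only while this method runs; previous handlers are then restored.
        return "done"


assert ToySampler().run() == "done"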
 
 
 class Sampler(object):
-    """ A sampler object to aid in setting up an inference run
+    """A sampler object to aid in setting up an inference run
 
     Parameters
     ==========
@@ -76,6 +169,10 @@ class Sampler(object):
         System exit code to return on interrupt
     kwargs: dict
         Dictionary of keyword arguments that can be used in the external sampler
+    hard_exit: bool
+        Whether the implemented sampler exits hard (:code:`os._exit` rather
+        than :code:`sys.exit`). The latter can be caught as :code:`SystemExit`;
+        the former cannot.
 
     Raises
     ======
@@ -89,15 +186,42 @@ class Sampler(object):
         If some of the priors can't be sampled
 
     """
+
     default_kwargs = dict()
-    npool_equiv_kwargs = ['queue_size', 'threads', 'nthreads', 'npool']
+    npool_equiv_kwargs = [
+        "npool",
+        "queue_size",
+        "threads",
+        "nthreads",
+        "cores",
+        "n_pool",
+    ]
+    sampling_seed_equiv_kwargs = ["sampling_seed", "seed", "random_seed"]
+    hard_exit = False
+    sampling_seed_key = None
+    """Name of keyword argument for setting the sampling for the specific sampler.
+    If a specific sampler does not have a sampling seed option, then it should be
+    left as None.
+    """
 
     def __init__(
-            self, likelihood, priors, outdir='outdir', label='label',
-            use_ratio=False, plot=False, skip_import_verification=False,
-            injection_parameters=None, meta_data=None, result_class=None,
-            likelihood_benchmark=False, soft_init=False, exit_code=130,
-            **kwargs):
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        injection_parameters=None,
+        meta_data=None,
+        result_class=None,
+        likelihood_benchmark=False,
+        soft_init=False,
+        exit_code=130,
+        npool=1,
+        **kwargs,
+    ):
         self.likelihood = likelihood
         if isinstance(priors, PriorDict):
             self.priors = priors
@@ -108,6 +232,7 @@ class Sampler(object):
         self.injection_parameters = injection_parameters
         self.meta_data = meta_data
         self.use_ratio = use_ratio
+        self._npool = npool
         if not skip_import_verification:
             self._verify_external_sampler()
         self.external_sampler_function = None
@@ -118,6 +243,7 @@ class Sampler(object):
         self._fixed_parameter_keys = list()
         self._constraint_parameter_keys = list()
         self._initialise_parameters()
+        self._log_information_about_priors_and_likelihood()
 
         self.exit_code = exit_code
 
@@ -159,7 +285,7 @@ class Sampler(object):
 
     @property
     def kwargs(self):
-        """dict: Container for the kwargs. Has more sophisticated logic in subclasses """
+        """dict: Container for the kwargs. Has more sophisticated logic in subclasses"""
         return self._kwargs
 
     @kwargs.setter
@@ -170,8 +296,16 @@ class Sampler(object):
         self._verify_kwargs_against_default_kwargs()
 
     def _translate_kwargs(self, kwargs):
-        """ Template for child classes """
-        pass
+        """Translate keyword arguments.
+
+        Default only translates the sampling seed if the sampler has
+        :code:`sampling_seed_key` set.
+        """
+        if self.sampling_seed_key and self.sampling_seed_key not in kwargs:
+            for equiv in self.sampling_seed_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs[self.sampling_seed_key] = kwargs.pop(equiv)
+        return kwargs
 
     @property
     def external_sampler_name(self):
@@ -180,10 +314,11 @@ class Sampler(object):
     def _verify_external_sampler(self):
         external_sampler_name = self.external_sampler_name
         try:
-            self.external_sampler = __import__(external_sampler_name)
+            __import__(external_sampler_name)
         except (ImportError, SystemExit):
             raise SamplerNotInstalledError(
-                "Sampler {} is not installed on this system".format(external_sampler_name))
+                f"Sampler {external_sampler_name} is not installed on this system"
+            )
 
     def _verify_kwargs_against_default_kwargs(self):
         """
@@ -195,8 +330,8 @@ class Sampler(object):
         for user_input in self.kwargs.keys():
             if user_input not in args:
                 logger.warning(
-                    "Supplied argument '{}' not an argument of '{}', removing."
-                    .format(user_input, self.__class__.__name__))
+                    f"Supplied argument '{user_input}' not an argument of '{self.__class__.__name__}', removing."
+                )
                 bad_keys.append(user_input)
         for key in bad_keys:
             self.kwargs.pop(key)
@@ -208,8 +343,10 @@ class Sampler(object):
         the respective parameter is fixed.
         """
         for key in self.priors:
-            if isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is False:
+            if (
+                isinstance(self.priors[key], Prior)
+                and self.priors[key].is_fixed is False
+            ):
                 self._search_parameter_keys.append(key)
             elif isinstance(self.priors[key], Constraint):
                 self._constraint_parameter_keys.append(key)
@@ -217,11 +354,16 @@ class Sampler(object):
                 self.likelihood.parameters[key] = self.priors[key].sample()
                 self._fixed_parameter_keys.append(key)
 
-        logger.info("Search parameters:")
+    def _log_information_about_priors_and_likelihood(self):
+        logger.info("Analysis priors:")
         for key in self._search_parameter_keys + self._constraint_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key]))
+            logger.info(f"{key}={self.priors[key]}")
         for key in self._fixed_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key].peak))
+            logger.info(f"{key}={self.priors[key].peak}")
+        logger.info(f"Analysis likelihood class: {self.likelihood.__class__}")
+        logger.info(
+            f"Analysis likelihood noise evidence: {self.likelihood.noise_log_likelihood()}"
+        )
 
     def _initialise_result(self, result_class):
         """
@@ -231,27 +373,30 @@ class Sampler(object):
 
         """
         result_kwargs = dict(
-            label=self.label, outdir=self.outdir,
+            label=self.label,
+            outdir=self.outdir,
             sampler=self.__class__.__name__.lower(),
             search_parameter_keys=self._search_parameter_keys,
             fixed_parameter_keys=self._fixed_parameter_keys,
             constraint_parameter_keys=self._constraint_parameter_keys,
-            priors=self.priors, meta_data=self.meta_data,
+            priors=self.priors,
+            meta_data=self.meta_data,
             injection_parameters=self.injection_parameters,
-            sampler_kwargs=self.kwargs, use_ratio=self.use_ratio)
+            sampler_kwargs=self.kwargs,
+            use_ratio=self.use_ratio,
+        )
 
         if result_class is None:
             result = Result(**result_kwargs)
         elif issubclass(result_class, Result):
             result = result_class(**result_kwargs)
         else:
-            raise ValueError(
-                "Input result_class={} not understood".format(result_class))
+            raise ValueError(f"Input result_class={result_class} not understood")
 
         return result
 
     def _verify_parameters(self):
-        """ Evaluate a set of parameters drawn from the prior
+        """Evaluate a set of parameters drawn from the prior
 
         Tests if the likelihood evaluation passes
 
@@ -264,20 +409,22 @@ class Sampler(object):
 
         if self.priors.test_has_redundant_keys():
             raise IllegalSamplingSetError(
-                "Your sampling set contains redundant parameters.")
+                "Your sampling set contains redundant parameters."
+            )
 
         theta = self.priors.sample_subset_constrained_as_array(
-            self.search_parameter_keys, size=1)[:, 0]
+            self.search_parameter_keys, size=1
+        )[:, 0]
         try:
             self.log_likelihood(theta)
         except TypeError as e:
             raise TypeError(
-                "Likelihood evaluation failed with message: \n'{}'\n"
-                "Have you specified all the parameters:\n{}"
-                .format(e, self.likelihood.parameters))
+                f"Likelihood evaluation failed with message: \n'{e}'\n"
+                f"Have you specified all the parameters:\n{self.likelihood.parameters}"
+            )
 
     def _time_likelihood(self, n_evaluations=100):
-        """ Times the likelihood evaluation and print an info message
+        """Times the likelihood evaluation and print an info message
 
         Parameters
         ==========
@@ -289,7 +436,8 @@ class Sampler(object):
         t1 = datetime.datetime.now()
         for _ in range(n_evaluations):
             theta = self.priors.sample_subset_constrained_as_array(
-                self._search_parameter_keys, size=1)[:, 0]
+                self._search_parameter_keys, size=1
+            )[:, 0]
             self.log_likelihood(theta)
         total_time = (datetime.datetime.now() - t1).total_seconds()
         self._log_likelihood_eval_time = total_time / n_evaluations
@@ -298,8 +446,9 @@ class Sampler(object):
             self._log_likelihood_eval_time = np.nan
             logger.info("Unable to measure single likelihood time")
         else:
-            logger.info("Single likelihood evaluation took {:.3e} s"
-                        .format(self._log_likelihood_eval_time))
+            logger.info(
+                f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s"
+            )
 
     def _verify_use_ratio(self):
         """
@@ -309,9 +458,9 @@ class Sampler(object):
         try:
             self.priors.sample_subset(self.search_parameter_keys)
         except (KeyError, AttributeError):
-            logger.error("Cannot sample from priors with keys: {}.".format(
-                self.search_parameter_keys
-            ))
+            logger.error(
+                f"Cannot sample from priors with keys: {self.search_parameter_keys}."
+            )
             raise
         if self.use_ratio is False:
             logger.debug("use_ratio set to False")
@@ -322,14 +471,14 @@ class Sampler(object):
         if self.use_ratio is True and ratio_is_nan:
             logger.warning(
                 "You have requested to use the loglikelihood_ratio, but it "
-                " returns a NaN")
+                " returns a NaN"
+            )
         elif self.use_ratio is None and not ratio_is_nan:
-            logger.debug(
-                "use_ratio not spec. but gives valid answer, setting True")
+            logger.debug("use_ratio not spec. but gives valid answer, setting True")
             self.use_ratio = True
 
     def prior_transform(self, theta):
-        """ Prior transform method that is passed into the external sampler.
+        """Prior transform method that is passed into the external sampler.
 
         Parameters
         ==========
@@ -355,8 +504,7 @@ class Sampler(object):
         float: Joint ln prior probability of theta
 
         """
-        params = {
-            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
         return self.priors.ln_prob(params)
 
     def log_likelihood(self, theta):
@@ -378,8 +526,7 @@ class Sampler(object):
                 self.likelihood_count.increment()
             except AttributeError:
                 pass
-        params = {
-            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
         self.likelihood.parameters.update(params)
         if self.use_ratio:
             return self.likelihood.log_likelihood_ratio()
@@ -387,7 +534,7 @@ class Sampler(object):
             return self.likelihood.log_likelihood()
 
     def get_random_draw_from_prior(self):
-        """ Get a random draw from the prior distribution
+        """Get a random draw from the prior distribution
 
         Returns
         =======
@@ -397,13 +544,12 @@ class Sampler(object):
 
         """
         new_sample = self.priors.sample()
-        draw = np.array(list(new_sample[key]
-                             for key in self._search_parameter_keys))
+        draw = np.array(list(new_sample[key] for key in self._search_parameter_keys))
         self.check_draw(draw)
         return draw
 
     def get_initial_points_from_prior(self, npoints=1):
-        """ Method to draw a set of live points from the prior
+        """Method to draw a set of live points from the prior
 
         This iterates over draws from the prior until all the samples have a
         finite prior and likelihood (relevant for constrained priors).
@@ -457,9 +603,11 @@ class Sampler(object):
         """
         log_p = self.log_prior(theta)
         log_l = self.log_likelihood(theta)
-        return \
-            self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
-            self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
+        return self._check_bad_value(
+            val=log_p, warning=warning, theta=theta, label="prior"
+        ) and self._check_bad_value(
+            val=log_l, warning=warning, theta=theta, label="likelihood"
+        )
 
     @staticmethod
     def _check_bad_value(val, warning, theta, label):
@@ -467,7 +615,7 @@ class Sampler(object):
         bad_values = [np.inf, np.nan_to_num(np.inf)]
         if val in bad_values or np.isnan(val):
             if warning:
-                logger.warning(f'Prior draw {theta} has inf {label}')
+                logger.warning(f"Prior draw {theta} has inf {label}")
             return False
         return True
 
@@ -485,7 +633,7 @@ class Sampler(object):
         raise ValueError("Method not yet implemented")
 
     def _check_cached_result(self):
-        """ Check if the cached data file exists and can be used """
+        """Check if the cached data file exists and can be used"""
 
         if command_line_args.clean:
             logger.debug("Command line argument clean given, forcing rerun")
@@ -493,30 +641,30 @@ class Sampler(object):
             return
 
         try:
-            self.cached_result = read_in_result(
-                outdir=self.outdir, label=self.label)
+            self.cached_result = read_in_result(outdir=self.outdir, label=self.label)
         except IOError:
             self.cached_result = None
 
         if command_line_args.use_cached:
-            logger.debug(
-                "Command line argument cached given, no cache check performed")
+            logger.debug("Command line argument cached given, no cache check performed")
             return
 
         logger.debug("Checking cached data")
         if self.cached_result:
-            check_keys = ['search_parameter_keys', 'fixed_parameter_keys']
+            check_keys = ["search_parameter_keys", "fixed_parameter_keys"]
             use_cache = True
             for key in check_keys:
-                if self.cached_result._check_attribute_match_to_other_object(
-                        key, self) is False:
-                    logger.debug("Cached value {} is unmatched".format(key))
+                if (
+                    self.cached_result._check_attribute_match_to_other_object(key, self)
+                    is False
+                ):
+                    logger.debug(f"Cached value {key} is unmatched")
                     use_cache = False
             try:
                 # Recursive check the dictionaries allowing for numpy arrays
                 np.testing.assert_equal(
                     self.meta_data["likelihood"],
-                    self.cached_result.meta_data["likelihood"]
+                    self.cached_result.meta_data["likelihood"],
                 )
             except AssertionError:
                 use_cache = False
@@ -531,13 +679,12 @@ class Sampler(object):
                 if type(kwargs_print[k]) in (list, np.ndarray):
                     array_repr = np.array(kwargs_print[k])
                     if array_repr.size > 10:
-                        kwargs_print[k] = ('array_like, shape={}'
-                                           .format(array_repr.shape))
+                        kwargs_print[k] = f"array_like, shape={array_repr.shape}"
                 elif type(kwargs_print[k]) == DataFrame:
-                    kwargs_print[k] = ('DataFrame, shape={}'
-                                       .format(kwargs_print[k].shape))
-            logger.info("Using sampler {} with kwargs {}".format(
-                self.__class__.__name__, kwargs_print))
+                    kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}"
+            logger.info(
+                f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}"
+            )
 
     def calc_likelihood_count(self):
         if self.likelihood_benchmark:
@@ -545,15 +692,100 @@ class Sampler(object):
         else:
             return None
 
+    @property
+    def npool(self):
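+        # a sampler-specific pool-size kwarg (e.g. nthreads for cpnest or
+        # queue_size for dynesty) takes precedence over the generic npool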
+        for key in self.npool_equiv_kwargs:
+            if key in self.kwargs:
+                return self.kwargs[key]
+        return self._npool
+
+    def _log_interruption(self, signum=None):
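+        # signum 14 is signal.SIGALRM, conventionally used to request a
+        # scheduled checkpoint before a job is evicted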
+        if signum == 14:
+            logger.info(
+                f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+        else:
+            logger.info(
+                f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Make sure that, if a pool of jobs is running, only the parent tries to
+        checkpoint and exit. Only the parent has a 'pool' attribute.
+
+        Samplers that must hard exit (typically because they wrap a non-Python
+        process) use :code:`os._exit`, which cannot be caught. Other samplers
+        exit by raising a catchable :code:`SystemExit`.
+        """
+        if self.npool in (1, None) or getattr(self, "pool", None) is not None:
+            self._log_interruption(signum=signum)
+            self.write_current_state()
+            self._close_pool()
+            if self.hard_exit:
+                os._exit(self.exit_code)
+            else:
+                sys.exit(self.exit_code)
+
+    def _close_pool(self):
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            self.kwargs["pool"] = self.pool
+            logger.info("Finished closing worker pool.")
+
+    def _setup_pool(self):
+        if self.kwargs.get("pool", None) is not None:
+            logger.info("Using user defined pool.")
+            self.pool = self.kwargs["pool"]
+        elif self.npool is not None and self.npool > 1:
+            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
+            import multiprocessing
+
+            self.pool = multiprocessing.Pool(
+                processes=self.npool,
+                initializer=_initialize_global_variables,
+                initargs=(
+                    self.likelihood,
+                    self.priors,
+                    self._search_parameter_keys,
+                    self.use_ratio,
+                ),
+            )
+        else:
+            self.pool = None
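+        # also initialize the module-level state in the parent process so
+        # that serial evaluations share the same code path as pool workers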
+        _initialize_global_variables(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            search_parameter_keys=self._search_parameter_keys,
+            use_ratio=self.use_ratio,
+        )
+        self.kwargs["pool"] = self.pool
+
+    def write_current_state(self):
+        raise NotImplementedError()
+
 
 class NestedSampler(Sampler):
-    npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints',
-                            'npoint', 'Nlive', 'num_live_points', 'num_particles']
-    walks_equiv_kwargs = ['walks', 'steps', 'nmcmc']
+    npoints_equiv_kwargs = [
+        "nlive",
+        "nlives",
+        "n_live_points",
+        "npoints",
+        "npoint",
+        "Nlive",
+        "num_live_points",
+        "num_particles",
+    ]
+    walks_equiv_kwargs = ["walks", "steps", "nmcmc"]
 
-    def reorder_loglikelihoods(self, unsorted_loglikelihoods, unsorted_samples,
-                               sorted_samples):
-        """ Reorders the stored log-likelihood after they have been reweighted
+    @staticmethod
+    def reorder_loglikelihoods(
+        unsorted_loglikelihoods, unsorted_samples, sorted_samples
+    ):
+        """Reorders the stored log-likelihood after they have been reweighted
 
         This creates a sorting index by matching the reweighted `result.samples`
         against the raw samples, then uses this index to sort the
@@ -578,12 +810,12 @@ class NestedSampler(Sampler):
 
         idxs = []
         for ii in range(len(unsorted_loglikelihoods)):
-            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples,
-                                  axis=1))[0]
+            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0]
             if len(idx) > 1:
                 logger.warning(
                     "Multiple likelihood matches found between sorted and "
-                    "unsorted samples. Taking the first match.")
+                    "unsorted samples. Taking the first match."
+                )
             idxs.append(idx[0])
         return unsorted_loglikelihoods[idxs]
 
@@ -601,52 +833,34 @@ class NestedSampler(Sampler):
         =======
         float: log_likelihood
         """
-        if self.priors.evaluate_constraints({
-                key: theta[ii] for ii, key in
-                enumerate(self.search_parameter_keys)}):
+        if self.priors.evaluate_constraints(
+            {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)}
+        ):
             return Sampler.log_likelihood(self, theta)
         else:
             return np.nan_to_num(-np.inf)
 
-    def _setup_run_directory(self):
-        """
-        If using a temporary directory, the output directory is moved to the
-        temporary directory.
-        Used for Dnest4, Pymultinest, and Ultranest.
-        """
-        if self.use_temporary_directory:
-            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
-            self.temporary_outputfiles_basename = temporary_outputfiles_basename
-
-            if os.path.exists(self.outputfiles_basename):
-                distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename)
-            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
-
-            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
-            logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
-        else:
-            check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-            logger.info("Using output file {}".format(self.outputfiles_basename))
-
 
 class MCMCSampler(Sampler):
-    nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws', 'Niter']
-    nburn_equiv_kwargs = ['burn', 'nburn']
+    nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"]
+    nburn_equiv_kwargs = ["burn", "nburn"]
 
     def print_nburn_logging_info(self):
-        """ Prints logging info as to how nburn was calculated """
+        """Prints logging info as to how nburn was calculated"""
         if type(self.nburn) in [float, int]:
-            logger.info("Discarding {} steps for burn-in".format(self.nburn))
+            logger.info(f"Discarding {self.nburn} steps for burn-in")
         elif self.result.max_autocorrelation_time is None:
-            logger.info("Autocorrelation time not calculated, discarding {} "
-                        " steps for burn-in".format(self.nburn))
+            logger.info(
+                f"Autocorrelation time not calculated, discarding "
+                f"{self.nburn} steps for burn-in"
+            )
         else:
-            logger.info("Discarding {} steps for burn-in, estimated from "
-                        "autocorr".format(self.nburn))
+            logger.info(
+                f"Discarding {self.nburn} steps for burn-in, estimated from autocorr"
+            )
 
     def calculate_autocorrelation(self, samples, c=3):
-        """ Uses the `emcee.autocorr` module to estimate the autocorrelation
+        """Uses the `emcee.autocorr` module to estimate the autocorrelation
 
         Parameters
         ==========
@@ -657,35 +871,155 @@ class MCMCSampler(Sampler):
             estimate (default: `3`). See `emcee.autocorr.integrated_time`.
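+
+        A usage sketch (assuming `samples` is an array of chains in a shape
+        accepted by `emcee.autocorr.integrated_time`)::
+
+            sampler.calculate_autocorrelation(samples)
+            print(sampler.result.max_autocorrelation_time)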
         """
         import emcee
+
         try:
-            self.result.max_autocorrelation_time = int(np.max(
-                emcee.autocorr.integrated_time(samples, c=c)))
-            logger.info("Max autocorr time = {}".format(
-                self.result.max_autocorrelation_time))
+            self.result.max_autocorrelation_time = int(
+                np.max(emcee.autocorr.integrated_time(samples, c=c))
+            )
+            logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}")
         except emcee.autocorr.AutocorrError as e:
             self.result.max_autocorrelation_time = None
-            logger.info("Unable to calculate autocorr time: {}".format(e))
+            logger.info(f"Unable to calculate autocorr time: {e}")
+
+
+class _TemporaryFileSamplerMixin:
+    """
+    A mixin class to handle storing sampler intermediate products in a temporary
+    location. See, e.g., `this StackOverflow answer
+    <https://stackoverflow.com/a/547714>`_ for a basic background on mixins.
+
+    This class makes sure that any subclasses can seamlessly use the temporary
+    file functionality.
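+
+    A minimal sketch of a subclass opting in (the :code:`MyNestedSampler` name
+    and :code:`"my"` short name are hypothetical)::
+
+        class MyNestedSampler(_TemporaryFileSamplerMixin, NestedSampler):
+            short_name = "my"
+
+    Listing the mixin first, as done for :code:`DNest4`, lets its
+    :code:`__init__` consume the :code:`temporary_directory` argument before
+    the base sampler's :code:`__init__` runs.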
+    """
+
+    short_name = ""
+
+    def __init__(self, temporary_directory, **kwargs):
+        super(_TemporaryFileSamplerMixin, self).__init__(**kwargs)
+        self.use_temporary_directory = temporary_directory
+        self._outputfiles_basename = None
+        self._temporary_outputfiles_basename = None
+
+    def _check_and_load_sampling_time_file(self):
+        if os.path.exists(self.time_file_path):
+            with open(self.time_file_path, "r") as time_file:
+                self.total_sampling_time = float(time_file.readline())
+        else:
+            self.total_sampling_time = 0
+
+    def _calculate_and_save_sampling_time(self):
+        current_time = time.time()
+        new_sampling_time = current_time - self.start_time
+        self.total_sampling_time += new_sampling_time
+
+        with open(self.time_file_path, "w") as time_file:
+            time_file.write(str(self.total_sampling_time))
+
+        self.start_time = current_time
+
+    def _clean_up_run_directory(self):
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
+            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
+
+    @property
+    def outputfiles_basename(self):
+        return self._outputfiles_basename
+
+    @outputfiles_basename.setter
+    def outputfiles_basename(self, outputfiles_basename):
+        if outputfiles_basename is None:
+            outputfiles_basename = f"{self.outdir}/{self.short_name}_{self.label}/"
+        if not outputfiles_basename.endswith("/"):
+            outputfiles_basename += "/"
+        check_directory_exists_and_if_not_mkdir(self.outdir)
+        self._outputfiles_basename = outputfiles_basename
+
+    @property
+    def temporary_outputfiles_basename(self):
+        return self._temporary_outputfiles_basename
+
+    @temporary_outputfiles_basename.setter
+    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
+        if not temporary_outputfiles_basename.endswith("/"):
+            temporary_outputfiles_basename += "/"
+        self._temporary_outputfiles_basename = temporary_outputfiles_basename
+        if os.path.exists(self.outputfiles_basename):
+            shutil.copytree(
+                self.outputfiles_basename, self.temporary_outputfiles_basename
+            )
+
+    def write_current_state(self):
+        self._calculate_and_save_sampling_time()
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
+
+    def _move_temporary_directory_to_proper_path(self):
+        """
+        Move the temporary directory back to the proper path.
+
+        Anything in the proper path at this point is removed, including links.
+        """
+        self._copy_temporary_directory_contents_to_proper_path()
+        shutil.rmtree(self.temporary_outputfiles_basename)
+
+    def _copy_temporary_directory_contents_to_proper_path(self):
+        """
+        Copy the temporary directory back to the proper path.
+        Do not delete the temporary directory.
+        """
+        logger.info(
+            f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}"
+        )
+        outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/")
+        distutils.dir_util.copy_tree(
+            self.temporary_outputfiles_basename, outputfiles_basename_stripped
+        )
+
+    def _setup_run_directory(self):
+        """
+        If using a temporary directory, the output is written to the temporary
+        directory and copied back to the output directory at checkpoints.
+        Used for Dnest4, Pymultinest, and Ultranest.
+        """
+        check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
+        if self.use_temporary_directory:
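+            # the TemporaryDirectory object is not kept alive, so only its
+            # generated path is used; the directory is (re)created below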
+            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
+            self.temporary_outputfiles_basename = temporary_outputfiles_basename
+
+            if os.path.exists(self.outputfiles_basename):
+                distutils.dir_util.copy_tree(
+                    self.outputfiles_basename, self.temporary_outputfiles_basename
+                )
+            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
+
+            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
+            logger.info(f"Using temporary file {temporary_outputfiles_basename}")
+        else:
+            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
+            logger.info(f"Using output file {self.outputfiles_basename}")
+        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
 
 
 class Error(Exception):
-    """ Base class for all exceptions raised by this module """
+    """Base class for all exceptions raised by this module"""
 
 
 class SamplerError(Error):
-    """ Base class for Error related to samplers in this module """
+    """Base class for Error related to samplers in this module"""
 
 
 class ResumeError(Error):
-    """ Class for errors arising from resuming runs """
+    """Class for errors arising from resuming runs"""
 
 
 class SamplerNotInstalledError(SamplerError):
-    """ Base class for Error raised by not installed samplers """
+    """Base class for Error raised by not installed samplers"""
 
 
 class IllegalSamplingSetError(Error):
-    """ Class for illegal sets of sampling parameters """
+    """Class for illegal sets of sampling parameters"""
 
 
 class SamplingMarginalisedParameterError(IllegalSamplingSetError):
-    """ Class for errors that occur when sampling over marginalized parameters """
+    """Class for errors that occur when sampling over marginalized parameters"""
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index e64365f2e423c9ac4968e8c9964f4d26cd55bc00..bc3b364656d26bcff0c14e3852bbbd394c5887cf 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -1,18 +1,18 @@
-
 import array
 import copy
+import sys
 
 import numpy as np
 from numpy.lib.recfunctions import structured_to_unstructured
 from pandas import DataFrame
 
-from .base_sampler import NestedSampler
-from .proposal import Sample, JumpProposalCycle
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import NestedSampler, signal_wrapper
+from .proposal import JumpProposalCycle, Sample
 
 
 class Cpnest(NestedSampler):
-    """ bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
+    """bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
 
     All positional and keyword arguments (i.e., the args and kwargs) passed to
     `run_sampler` will be propagated to `cpnest.CPNest`, see documentation
@@ -39,30 +39,46 @@ class Cpnest(NestedSampler):
         {self.outdir}/cpnest_{self.label}/
 
     """
-    default_kwargs = dict(verbose=3, nthreads=1, nlive=500, maxmcmc=1000,
-                          seed=None, poolsize=100, nhamiltonian=0, resume=True,
-                          output=None, proposals=None, n_periodic_checkpoint=8000)
+
+    default_kwargs = dict(
+        verbose=3,
+        nthreads=1,
+        nlive=500,
+        maxmcmc=1000,
+        seed=None,
+        poolsize=100,
+        nhamiltonian=0,
+        resume=True,
+        output=None,
+        proposals=None,
+        n_periodic_checkpoint=8000,
+    )
+    hard_exit = True
+    sampling_seed_key = "seed"
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'nthreads' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "nthreads" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nthreads'] = kwargs.pop(equiv)
+                    kwargs["nthreads"] = kwargs.pop(equiv)
 
-        if 'seed' not in kwargs:
-            logger.warning('No seed provided, cpnest will use 1234.')
+        if "seed" not in kwargs:
+            logger.warning("No seed provided, cpnest will use 1234.")
 
+    @signal_wrapper
     def run_sampler(self):
-        from cpnest import model as cpmodel, CPNest
-        from cpnest.parameter import LivePoint
+        from cpnest import CPNest
+        from cpnest import model as cpmodel
         from cpnest.nest2pos import compute_weights
+        from cpnest.parameter import LivePoint
 
         class Model(cpmodel.Model):
-            """ A wrapper class to pass our log_likelihood into cpnest """
+            """A wrapper class to pass our log_likelihood into cpnest"""
 
             def __init__(self, names, priors):
                 self.names = names
@@ -82,14 +98,16 @@ class Cpnest(NestedSampler):
             def _update_bounds(self):
                 self.bounds = [
                     [self.priors[key].minimum, self.priors[key].maximum]
-                    for key in self.names]
+                    for key in self.names
+                ]
 
             def new_point(self):
                 """Draw a point from the prior"""
                 prior_samples = self.priors.sample()
                 self._update_bounds()
                 point = LivePoint(
-                    self.names, array.array('d', [prior_samples[name] for name in self.names])
+                    self.names,
+                    array.array("d", [prior_samples[name] for name in self.names]),
                 )
                 return point
 
@@ -105,18 +123,14 @@ class Cpnest(NestedSampler):
                     kwarg = remove_kwargs.pop(0)
                 else:
                     raise TypeError("Unable to initialise cpnest sampler")
-                logger.info(
-                    "CPNest init. failed with error {}, please update"
-                    .format(e))
-                logger.info(
-                    "Attempting to rerun with kwarg {} removed".format(kwarg))
+                logger.info(f"CPNest init. failed with error {e}, please update")
+                logger.info(f"Attempting to rerun with kwarg {kwarg} removed")
                 self.kwargs.pop(kwarg)
         try:
             out.run()
-        except SystemExit as e:
-            import sys
-            logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code))
-            sys.exit(self.exit_code)
+        except SystemExit:
+            out.checkpoint()
+            self.write_current_state_and_exit()
 
         if self.plot:
             out.plot()
@@ -125,42 +139,58 @@ class Cpnest(NestedSampler):
         self.result.samples = structured_to_unstructured(
             out.posterior_samples[self.search_parameter_keys]
         )
-        self.result.log_likelihood_evaluations = out.posterior_samples['logL']
-        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=''))
-        self.result.nested_samples.rename(columns=dict(logL='log_likelihood'), inplace=True)
-        _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood),
-                                         np.array(out.NS.state.nlive))
-        self.result.nested_samples['weights'] = np.exp(log_weights)
+        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
+        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=""))
+        self.result.nested_samples.rename(
+            columns=dict(logL="log_likelihood"), inplace=True
+        )
+        _, log_weights = compute_weights(
+            np.array(self.result.nested_samples.log_likelihood),
+            np.array(out.NS.state.nlive),
+        )
+        self.result.nested_samples["weights"] = np.exp(log_weights)
         self.result.log_evidence = out.NS.state.logZ
         self.result.log_evidence_err = np.sqrt(out.NS.state.info / out.NS.state.nlive)
         self.result.information_gain = out.NS.state.info
         return self.result
 
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Overrides the base class to make sure that :code:`CPNest` terminates
+        properly, as :code:`CPNest` handles all the multiprocessing internally.
+        """
+        self._log_interruption(signum=signum)
+        sys.exit(self.exit_code)
+
     def _verify_kwargs_against_default_kwargs(self):
         """
         Set the directory where the output will be written
         and check resume and checkpoint status.
         """
-        if not self.kwargs['output']:
-            self.kwargs['output'] = \
-                '{}/cpnest_{}/'.format(self.outdir, self.label)
-        if self.kwargs['output'].endswith('/') is False:
-            self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
-        check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
-        if self.kwargs['n_periodic_checkpoint'] and not self.kwargs['resume']:
-            self.kwargs['n_periodic_checkpoint'] = None
+        if not self.kwargs["output"]:
+            self.kwargs["output"] = f"{self.outdir}/cpnest_{self.label}/"
+        if self.kwargs["output"].endswith("/") is False:
+            self.kwargs["output"] = f"{self.kwargs['output']}/"
+        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
+        if self.kwargs["n_periodic_checkpoint"] and not self.kwargs["resume"]:
+            self.kwargs["n_periodic_checkpoint"] = None
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
     def _resolve_proposal_functions(self):
         from cpnest.proposal import ProposalCycle
-        if 'proposals' in self.kwargs:
-            if self.kwargs['proposals'] is None:
+
+        if "proposals" in self.kwargs:
+            if self.kwargs["proposals"] is None:
                 return
-            if type(self.kwargs['proposals']) == JumpProposalCycle:
-                self.kwargs['proposals'] = dict(mhs=self.kwargs['proposals'], hmc=self.kwargs['proposals'])
-            for key, proposal in self.kwargs['proposals'].items():
+            if type(self.kwargs["proposals"]) == JumpProposalCycle:
+                self.kwargs["proposals"] = dict(
+                    mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"]
+                )
+            for key, proposal in self.kwargs["proposals"].items():
                 if isinstance(proposal, JumpProposalCycle):
-                    self.kwargs['proposals'][key] = cpnest_proposal_cycle_factory(proposal)
+                    self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory(
+                        proposal
+                    )
                 elif isinstance(proposal, ProposalCycle):
                     pass
                 else:
@@ -171,7 +201,6 @@ def cpnest_proposal_factory(jump_proposal):
     import cpnest.proposal
 
     class CPNestEnsembleProposal(cpnest.proposal.EnsembleProposal):
-
         def __init__(self, jp):
             self.jump_proposal = jp
             self.ensemble = None
@@ -181,8 +210,8 @@ def cpnest_proposal_factory(jump_proposal):
 
         def get_sample(self, cpnest_sample, **kwargs):
             sample = Sample.from_cpnest_live_point(cpnest_sample)
-            self.ensemble = kwargs.get('coordinates', self.ensemble)
-            sample = self.jump_proposal(sample=sample, sampler_name='cpnest', **kwargs)
+            self.ensemble = kwargs.get("coordinates", self.ensemble)
+            sample = self.jump_proposal(sample=sample, sampler_name="cpnest", **kwargs)
             self.log_J = self.jump_proposal.log_j
             return self._update_cpnest_sample(cpnest_sample, sample)
 
@@ -203,11 +232,15 @@ def cpnest_proposal_cycle_factory(jump_proposals):
         def __init__(self):
             self.jump_proposals = copy.deepcopy(jump_proposals)
             for i, prop in enumerate(self.jump_proposals.proposal_functions):
-                self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(prop)
+                self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(
+                    prop
+                )
             self.jump_proposals.update_cycle()
-            super(CPNestProposalCycle, self).__init__(proposals=self.jump_proposals.proposal_functions,
-                                                      weights=self.jump_proposals.weights,
-                                                      cyclelength=self.jump_proposals.cycle_length)
+            super(CPNestProposalCycle, self).__init__(
+                proposals=self.jump_proposals.proposal_functions,
+                weights=self.jump_proposals.weights,
+                cyclelength=self.jump_proposals.cycle_length,
+            )
 
         def get_sample(self, old, **kwargs):
             return self.jump_proposals(sample=old, coordinates=self.ensemble, **kwargs)
diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py
index ef80c13e933e4dfd6fcaa5e4c3ea8f113b15928e..5c3d7566e729fb92e5427e9c8b5968c3df6c6abe 100644
--- a/bilby/core/sampler/dnest4.py
+++ b/bilby/core/sampler/dnest4.py
@@ -1,21 +1,17 @@
-import os
-import shutil
-import distutils.dir_util
-import signal
-import time
 import datetime
-import sys
+import time
 
 import numpy as np
 import pandas as pd
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
-from .base_sampler import NestedSampler
+from ..utils import logger
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
 class _DNest4Model(object):
-
-    def __init__(self, log_likelihood_func, from_prior_func, widths, centers, highs, lows):
+    def __init__(
+        self, log_likelihood_func, from_prior_func, widths, centers, highs, lows
+    ):
         """Initialize the DNest4 model.
         Args:
             log_likelihood_func: function
@@ -48,7 +44,7 @@ class _DNest4Model(object):
         """The perturb function to perform Monte Carlo trial moves."""
         idx = np.random.randint(self._n_dim)
 
-        coords[idx] += (self._widths[idx] * (np.random.uniform(size=1) - 0.5))
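+        # propose a uniform trial move of width self._widths[idx] centred on
+        # the current value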
+        coords[idx] += self._widths[idx] * (np.random.uniform(size=1) - 0.5)
         cw = self._widths[idx]
         cc = self._centers[idx]
 
@@ -59,11 +55,13 @@ class _DNest4Model(object):
     @staticmethod
     def wrap(x, minimum, maximum):
         if maximum <= minimum:
-            raise ValueError("maximum {} <= minimum {}, when trying to wrap coordinates".format(maximum, minimum))
+            raise ValueError(
+                f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates"
+            )
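+        # periodic wrap: e.g. wrap(1.25, 0.0, 1.0) == 0.25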
         return (x - minimum) % (maximum - minimum) + minimum
 
 
-class DNest4(NestedSampler):
+class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
 
     """
     Bilby wrapper of DNest4
@@ -100,35 +98,59 @@ class DNest4(NestedSampler):
         If True, prints information during run
     """
 
-    default_kwargs = dict(max_num_levels=20, num_steps=500,
-                          new_level_interval=10000, num_per_step=10000,
-                          thread_steps=1, num_particles=1000, lam=10.0,
-                          beta=100, seed=None, verbose=True, outputfiles_basename=None,
-                          backend='memory')
-
-    def __init__(self, likelihood, priors, outdir="outdir", label="label", use_ratio=False, plot=False,
-                 exit_code=77, skip_import_verification=False, temporary_directory=True, **kwargs):
+    default_kwargs = dict(
+        max_num_levels=20,
+        num_steps=500,
+        new_level_interval=10000,
+        num_per_step=10000,
+        thread_steps=1,
+        num_particles=1000,
+        lam=10.0,
+        beta=100,
+        seed=None,
+        verbose=True,
+        outputfiles_basename=None,
+        backend="memory",
+    )
+    short_name = "dn4"
+    hard_exit = True
+    sampling_seed_key = "seed"
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        exit_code=77,
+        skip_import_verification=False,
+        temporary_directory=True,
+        **kwargs,
+    ):
         super(DNest4, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-            use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
-            exit_code=exit_code, **kwargs)
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            temporary_directory=temporary_directory,
+            exit_code=exit_code,
+            **kwargs,
+        )
 
         self.num_particles = self.kwargs["num_particles"]
         self.max_num_levels = self.kwargs["max_num_levels"]
         self._verbose = self.kwargs["verbose"]
         self._backend = self.kwargs["backend"]
-        self.use_temporary_directory = temporary_directory
 
         self.start_time = np.nan
         self.sampler = None
         self._information = np.nan
         self._last_live_sample_info = np.nan
-        self._outputfiles_basename = None
-        self._temporary_outputfiles_basename = None
-
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
 
         # Get the estimates of the prior distributions' widths and centers.
         widths = []
@@ -155,13 +177,22 @@ class DNest4(NestedSampler):
         self._highs = np.array(highs)
         self._lows = np.array(lows)
 
-        self._dnest4_model = _DNest4Model(self.log_likelihood, self.get_random_draw_from_prior, self._widths,
-                                          self._centers, self._highs, self._lows)
+        self._dnest4_model = _DNest4Model(
+            self.log_likelihood,
+            self.get_random_draw_from_prior,
+            self._widths,
+            self._centers,
+            self._highs,
+            self._lows,
+        )
 
     def _set_backend(self):
         import dnest4
-        if self._backend == 'csv':
-            return dnest4.backends.CSVBackend("{}/dnest4{}/".format(self.outdir, self.label), sep=" ")
+
+        if self._backend == "csv":
+            return dnest4.backends.CSVBackend(
+                f"{self.outdir}/dnest4{self.label}/", sep=" "
+            )
         else:
             return dnest4.backends.MemoryBackend()
 
@@ -169,6 +200,7 @@ class DNest4(NestedSampler):
         dnest4_keys = ["num_steps", "new_level_interval", "lam", "beta", "seed"]
         self.dnest4_kwargs = {key: self.kwargs[key] for key in dnest4_keys}
 
+    @signal_wrapper
     def run_sampler(self):
         import dnest4
 
@@ -181,31 +213,37 @@ class DNest4(NestedSampler):
         self.start_time = time.time()
 
         self.sampler = dnest4.DNest4Sampler(self._dnest4_model, backend=backend)
-        out = self.sampler.sample(self.max_num_levels,
-                                  num_particles=self.num_particles,
-                                  **self.dnest4_kwargs)
+        out = self.sampler.sample(
+            self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs
+        )
 
         for i, sample in enumerate(out):
             if self._verbose and ((i + 1) % 100 == 0):
                 stats = self.sampler.postprocess()
-                logger.info("Iteration: {0} log(Z): {1}".format(i + 1, stats['log_Z']))
+                logger.info(f"Iteration: {i + 1} log(Z): {stats['log_Z']}")
 
         self._calculate_and_save_sampling_time()
         self._clean_up_run_directory()
 
         stats = self.sampler.postprocess(resample=1)
-        self.result.log_evidence = stats['log_Z']
-        self._information = stats['H']
+        self.result.log_evidence = stats["log_Z"]
+        self._information = stats["H"]
         self.result.log_evidence_err = np.sqrt(self._information / self.num_particles)
 
-        if self._backend == 'memory':
-            self._last_live_sample_info = pd.DataFrame(self.sampler.backend.sample_info[-1])
-            self.result.log_likelihood_evaluations = self._last_live_sample_info['log_likelihood']
+        if self._backend == "memory":
+            self._last_live_sample_info = pd.DataFrame(
+                self.sampler.backend.sample_info[-1]
+            )
+            self.result.log_likelihood_evaluations = self._last_live_sample_info[
+                "log_likelihood"
+            ]
             self.result.samples = np.array(self.sampler.backend.posterior_samples)
         else:
-            sample_info_path = './' + self.kwargs["outputfiles_basename"] + '/sample_info.txt'
-            sample_info = np.genfromtxt(sample_info_path, comments='#', names=True)
-            self.result.log_likelihood_evaluations = sample_info['log_likelihood']
+            sample_info_path = (
+                "./" + self.kwargs["outputfiles_basename"] + "/sample_info.txt"
+            )
+            sample_info = np.genfromtxt(sample_info_path, comments="#", names=True)
+            self.result.log_likelihood_evaluations = sample_info["log_likelihood"]
             self.result.samples = np.array(self.sampler.backend.posterior_samples)
 
         self.result.sampler_output = out
@@ -217,100 +255,12 @@ class DNest4(NestedSampler):
         return self.result
 
     def _translate_kwargs(self, kwargs):
-        if 'num_steps' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "num_steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['num_steps'] = kwargs.pop(equiv)
+                    kwargs["num_steps"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
         super(DNest4, self)._verify_kwargs_against_default_kwargs()
-
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, 'r') as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-
-        with open(self.time_file_path, 'w') as time_file:
-            time_file.write(str(self.total_sampling_time))
-
-        self.start_time = current_time
-
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = "{}/dnest4{}/".format(self.outdir, self.label)
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        sys.exit(self.exit_code)
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
-
-    def _copy_temporary_directory_contents_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path.
-        Do not delete the temporary directory.
-        """
-        logger.info(
-            "Overwriting {} with {}".format(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-        )
-        if self.outputfiles_basename.endswith('/'):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py
index 8bb6d647aad013c5be957cfde6311f10f9feda82..ef28f22ddbb2ce099d481c3adebaae1ef1d3b0cd 100644
--- a/bilby/core/sampler/dynamic_dynesty.py
+++ b/bilby/core/sampler/dynamic_dynesty.py
@@ -1,12 +1,10 @@
-
-import os
-import signal
+import datetime
 
 import numpy as np
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import Sampler
-from .dynesty import Dynesty
+from ..utils import logger
+from .base_sampler import Sampler, signal_wrapper
+from .dynesty import Dynesty, _log_likelihood_wrapper, _prior_transform_wrapper
 
 
 class DynamicDynesty(Dynesty):
@@ -62,33 +60,77 @@ class DynamicDynesty(Dynesty):
     resume: bool
         If true, resume run from checkpoint (if available)
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          verbose=True,
-                          check_point_delta_t=600,
-                          first_update=None,
-                          npdim=None, rstate=None, queue_size=None, pool=None,
-                          use_pool=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
-                          facc=0.5, slices=5,
-                          walks=None, update_interval=0.6,
-                          nlive_init=500, maxiter_init=None, maxcall_init=None,
-                          dlogz_init=0.01, logl_max_init=np.inf, nlive_batch=500,
-                          wt_function=None, wt_kwargs=None, maxiter_batch=None,
-                          maxcall_batch=None, maxiter=None, maxcall=None,
-                          maxbatch=None, stop_function=None, stop_kwargs=None,
-                          use_stop=True, save_bounds=True,
-                          print_progress=True, print_func=None, live_points=None,
-                          )
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600,
-                 resume=True, **kwargs):
-        super(DynamicDynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                             outdir=outdir, label=label, use_ratio=use_ratio,
-                                             plot=plot, skip_import_verification=skip_import_verification,
-                                             **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        verbose=True,
+        check_point_delta_t=600,
+        first_update=None,
+        npdim=None,
+        rstate=None,
+        queue_size=None,
+        pool=None,
+        use_pool=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=None,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=2.0,
+        facc=0.5,
+        slices=5,
+        walks=None,
+        update_interval=0.6,
+        nlive_init=500,
+        maxiter_init=None,
+        maxcall_init=None,
+        dlogz_init=0.01,
+        logl_max_init=np.inf,
+        nlive_batch=500,
+        wt_function=None,
+        wt_kwargs=None,
+        maxiter_batch=None,
+        maxcall_batch=None,
+        maxiter=None,
+        maxcall=None,
+        maxbatch=None,
+        stop_function=None,
+        stop_kwargs=None,
+        use_stop=True,
+        save_bounds=True,
+        print_progress=True,
+        print_func=None,
+        live_points=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        **kwargs,
+    ):
+        super(DynamicDynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.resume = resume
@@ -97,39 +139,59 @@ class DynamicDynesty(Dynesty):
             # check_point is set to False.
             if np.isnan(self._log_likelihood_eval_time):
                 self.check_point = False
-            n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
-            n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
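+            # round the estimated number of calls between checkpoints to one
+            # significant figure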
+            n_check_point_raw = check_point_delta_t / self._log_likelihood_eval_time
+            n_check_point_rnd = int(float(f"{n_check_point_raw:1.0g}"))
             self.n_check_point = n_check_point_rnd
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
-
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
 
     @property
     def external_sampler_name(self):
-        return 'dynesty'
+        return "dynesty"
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['nlive_init', 'maxiter_init', 'maxcall_init', 'dlogz_init',
-                'logl_max_init', 'nlive_batch', 'wt_function', 'wt_kwargs',
-                'maxiter_batch', 'maxcall_batch', 'maxiter', 'maxcall',
-                'maxbatch', 'stop_function', 'stop_kwargs', 'use_stop',
-                'save_bounds', 'print_progress', 'print_func', 'live_points']
+        keys = [
+            "nlive_init",
+            "maxiter_init",
+            "maxcall_init",
+            "dlogz_init",
+            "logl_max_init",
+            "nlive_batch",
+            "wt_function",
+            "wt_kwargs",
+            "maxiter_batch",
+            "maxcall_batch",
+            "maxiter",
+            "maxcall",
+            "maxbatch",
+            "stop_function",
+            "stop_kwargs",
+            "use_stop",
+            "save_bounds",
+            "print_progress",
+            "print_func",
+            "live_points",
+        ]
         return {key: self.kwargs[key] for key in keys}
 
+    @signal_wrapper
     def run_sampler(self):
         import dynesty
+
+        self._setup_pool()
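+        # _setup_pool stores the likelihood, priors, and search keys in
+        # module-level state so that the picklable wrapper functions passed
+        # to dynesty below can access them from worker processes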
         self.sampler = dynesty.DynamicNestedSampler(
-            loglikelihood=self.log_likelihood,
-            prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.sampler_init_kwargs)
+            loglikelihood=_log_likelihood_wrapper,
+            prior_transform=_prior_transform_wrapper,
+            ndim=self.ndim,
+            **self.sampler_init_kwargs,
+        )
 
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._close_pool()
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
@@ -147,13 +209,14 @@ class DynamicDynesty(Dynesty):
         if self.resume:
             resume = self.read_saved_state(continuing=True)
             if resume:
-                logger.info('Resuming from previous run.')
+                logger.info("Resuming from previous run.")
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
+        sampler_kwargs["maxcall"] = self.n_check_point
+        self.start_time = datetime.datetime.now()
         while True:
-            sampler_kwargs['maxcall'] += self.n_check_point
+            sampler_kwargs["maxcall"] += self.n_check_point
             self.sampler.run_nested(**sampler_kwargs)
             if self.sampler.ncall == old_ncall:
                 break
@@ -164,27 +227,8 @@ class DynamicDynesty(Dynesty):
         self._remove_checkpoint()
         return self.sampler.results
 
-    def write_current_state(self):
-        """
-        """
-        import dill
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        with open(self.resume_file, 'wb') as file:
-            dill.dump(self, file)
-
-    def read_saved_state(self, continuing=False):
-        """
-        """
-        import dill
-
-        logger.debug("Reading resume file {}".format(self.resume_file))
-        if os.path.isfile(self.resume_file):
-            with open(self.resume_file, 'rb') as file:
-                self = dill.load(file)
-        else:
-            logger.debug(
-                "Failed to read resume file {}".format(self.resume_file))
-            return False
+    def write_current_state_and_exit(self, signum=None, frame=None):
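+        # fall back to the generic Sampler implementation, bypassing the
+        # Dynesty override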
+        Sampler.write_current_state_and_exit(self=self, signum=signum, frame=frame)
 
     def _verify_kwargs_against_default_kwargs(self):
         Sampler._verify_kwargs_against_default_kwargs(self)
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index d9f1ae94bc4995c0b1f0bd88649ff3850e571085..82da609e77a17613f90c248b6b04b7b1c2854bf9 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -1,62 +1,51 @@
 import datetime
 import os
 import sys
-import signal
 import time
 import warnings
 
 import numpy as np
 from pandas import DataFrame
 
+from ..result import rejection_sample
 from ..utils import (
-    logger,
     check_directory_exists_and_if_not_mkdir,
+    latex_plot_format,
+    logger,
     reflect,
     safe_file_dump,
-    latex_plot_format,
 )
-from .base_sampler import Sampler, NestedSampler
-from ..result import rejection_sample
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-        likelihood, priors, search_parameter_keys, use_ratio
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
+from .base_sampler import NestedSampler, Sampler, signal_wrapper
 
 
 def _prior_transform_wrapper(theta):
     """Wrapper to the prior transformation. Needed for multiprocessing."""
-    return _priors.rescale(_search_parameter_keys, theta)
+    from .base_sampler import _sampling_convenience_dump
+
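+    # the local import reads the module-level state populated by
+    # _initialize_global_variables when the pool was created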
+    return _sampling_convenience_dump.priors.rescale(
+        _sampling_convenience_dump.search_parameter_keys, theta
+    )
 
 
 def _log_likelihood_wrapper(theta):
     """Wrapper to the log likelihood. Needed for multiprocessing."""
-    if _priors.evaluate_constraints({
-        key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
-    }):
-        params = {key: t for key, t in zip(_search_parameter_keys, theta)}
-        _likelihood.parameters.update(params)
-        if _use_ratio:
-            return _likelihood.log_likelihood_ratio()
+    from .base_sampler import _sampling_convenience_dump
+
+    if _sampling_convenience_dump.priors.evaluate_constraints(
+        {
+            key: theta[ii]
+            for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys)
+        }
+    ):
+        params = {
+            key: t
+            for key, t in zip(_sampling_convenience_dump.search_parameter_keys, theta)
+        }
+        _sampling_convenience_dump.likelihood.parameters.update(params)
+        if _sampling_convenience_dump.use_ratio:
+            return _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            return _likelihood.log_likelihood()
+            return _sampling_convenience_dump.likelihood.log_likelihood()
     else:
         return np.nan_to_num(-np.inf)
 
@@ -130,32 +119,77 @@ class Dynesty(NestedSampler):
           e.g., 'interval-10' prints every ten seconds, this does not print every iteration
         - else: print to `stdout` at every iteration
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          periodic=None, reflective=None,
-                          check_point_delta_t=1800, nlive=1000,
-                          first_update=None, walks=100,
-                          npdim=None, rstate=None, queue_size=1, pool=None,
-                          use_pool=None, live_points=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
-                          facc=0.2, slices=5,
-                          update_interval=None, print_func=None,
-                          dlogz=0.1, maxiter=None, maxcall=None,
-                          logl_max=np.inf, add_live=True, print_progress=True,
-                          save_bounds=False, n_effective=None,
-                          maxmcmc=5000, nact=5, print_method="tqdm")
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 check_point=True, check_point_plot=True, n_check_point=None,
-                 check_point_delta_t=600, resume=True, nestcheck=False, exit_code=130, **kwargs):
-
-        super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                      outdir=outdir, label=label, use_ratio=use_ratio,
-                                      plot=plot, skip_import_verification=skip_import_verification,
-                                      exit_code=exit_code,
-                                      **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        print_progress=True,
+        periodic=None,
+        reflective=None,
+        check_point_delta_t=1800,
+        nlive=1000,
+        first_update=None,
+        walks=100,
+        npdim=None,
+        rstate=None,
+        queue_size=1,
+        pool=None,
+        use_pool=None,
+        live_points=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=1.5,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=8.0,
+        facc=0.2,
+        slices=5,
+        update_interval=None,
+        print_func=None,
+        dlogz=0.1,
+        maxiter=None,
+        maxcall=None,
+        logl_max=np.inf,
+        add_live=True,
+        save_bounds=False,
+        n_effective=None,
+        maxmcmc=5000,
+        nact=5,
+        print_method="tqdm",
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        check_point_plot=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        nestcheck=False,
+        exit_code=130,
+        **kwargs,
+    ):
+        self._translate_kwargs(kwargs)
+        super(Dynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            exit_code=exit_code,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.check_point_plot = check_point_plot
@@ -169,77 +203,80 @@ class Dynesty(NestedSampler):
         if self.n_check_point is None:
             self.n_check_point = 1000
         self.check_point_delta_t = check_point_delta_t
-        logger.info("Checkpoint every check_point_delta_t = {}s"
-                    .format(check_point_delta_t))
+        logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
         self.sampling_time = datetime.timedelta()
 
-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern.")
-
     def __getstate__(self):
-        """ For pickle: remove external_sampler, which can be an unpicklable "module" """
+        """For pickle: remove external_sampler, which can be an unpicklable "module" """
         state = self.__dict__.copy()
         if "external_sampler" in state:
-            del state['external_sampler']
+            del state["external_sampler"]
         return state
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
-                'maxcall', 'logl_max', 'add_live', 'save_bounds',
-                'n_effective']
+        keys = [
+            "dlogz",
+            "print_progress",
+            "print_func",
+            "maxiter",
+            "maxcall",
+            "logl_max",
+            "add_live",
+            "save_bounds",
+            "n_effective",
+        ]
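+        # Keys consumed by run_nested(); everything else in self.kwargs forms
+        # sampler_init_kwargs and is passed to the NestedSampler constructor.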
         return {key: self.kwargs[key] for key in keys}
 
     @property
     def sampler_init_kwargs(self):
-        return {key: value
-                for key, value in self.kwargs.items()
-                if key not in self.sampler_function_kwargs}
+        return {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'print_progress' not in kwargs:
-            if 'verbose' in kwargs:
-                kwargs['print_progress'] = kwargs.pop('verbose')
-        if 'walks' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "print_progress" not in kwargs:
+            if "verbose" in kwargs:
+                kwargs["print_progress"] = kwargs.pop("verbose")
+        if "walks" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['walks'] = kwargs.pop(equiv)
+                    kwargs["walks"] = kwargs.pop(equiv)
         if "queue_size" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['queue_size'] = kwargs.pop(equiv)
+                    kwargs["queue_size"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         from tqdm.auto import tqdm
-        if not self.kwargs['walks']:
-            self.kwargs['walks'] = 100
-        if not self.kwargs['update_interval']:
-            self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
-        if self.kwargs['print_func'] is None:
-            self.kwargs['print_func'] = self._print_func
+
+        if not self.kwargs["walks"]:
+            self.kwargs["walks"] = 100
+        if not self.kwargs["update_interval"]:
+            self.kwargs["update_interval"] = int(0.6 * self.kwargs["nlive"])
+        if self.kwargs["print_func"] is None:
+            self.kwargs["print_func"] = self._print_func
             print_method = self.kwargs["print_method"]
             if print_method == "tqdm" and self.kwargs["print_progress"]:
                 self.pbar = tqdm(file=sys.stdout)
             elif "interval" in print_method:
                 self._last_print_time = datetime.datetime.now()
-                self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
+                self._print_interval = datetime.timedelta(
+                    seconds=float(print_method.split("-")[1])
+                )
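+                # A print_method of the form "interval-<seconds>" throttles
+                # status updates to at most one per interval.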
         Sampler._verify_kwargs_against_default_kwargs(self)
 
     def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
-        """ Replacing status update for dynesty.result.print_func """
+        """Replacing status update for dynesty.result.print_func"""
         if "interval" in self.kwargs["print_method"]:
             _time = datetime.datetime.now()
             if _time - self._last_print_time < self._print_interval:
@@ -251,17 +288,31 @@ class Dynesty(NestedSampler):
                 total_time = self.sampling_time + _time - self.start_time
 
                 # Remove fractional seconds
-                total_time_str = str(total_time).split('.')[0]
+                total_time_str = str(total_time).split(".")[0]
 
         # Extract results at the current iteration.
-        (worst, ustar, vstar, loglstar, logvol, logwt,
-         logz, logzvar, h, nc, worst_it, boundidx, bounditer,
-         eff, delta_logz) = results
+        (
+            worst,
+            ustar,
+            vstar,
+            loglstar,
+            logvol,
+            logwt,
+            logz,
+            logzvar,
+            h,
+            nc,
+            worst_it,
+            boundidx,
+            bounditer,
+            eff,
+            delta_logz,
+        ) = results
 
         # Adjusting outputs for printing.
         if delta_logz > 1e6:
             delta_logz = np.inf
-        if 0. <= logzvar <= 1e6:
+        if 0.0 <= logzvar <= 1e6:
             logzerr = np.sqrt(logzvar)
         else:
             logzerr = np.nan
@@ -271,38 +322,38 @@ class Dynesty(NestedSampler):
             loglstar = -np.inf
 
         if self.use_ratio:
-            key = 'logz-ratio'
+            key = "logz-ratio"
         else:
-            key = 'logz'
+            key = "logz"
 
         # Constructing output.
         string = list()
-        string.append("bound:{:d}".format(bounditer))
-        string.append("nc:{:3d}".format(nc))
-        string.append("ncall:{:.1e}".format(ncall))
-        string.append("eff:{:0.1f}%".format(eff))
-        string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
-        string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
+        string.append(f"bound:{bounditer:d}")
+        string.append(f"nc:{nc:3d}")
+        string.append(f"ncall:{ncall:.1e}")
+        string.append(f"eff:{eff:0.1f}%")
+        string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}")
+        string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}")
 
         if self.kwargs["print_method"] == "tqdm":
             self.pbar.set_postfix_str(" ".join(string), refresh=False)
             self.pbar.update(niter - self.pbar.n)
         elif "interval" in self.kwargs["print_method"]:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
+            print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True)
         else:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
+            print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True)
 
     def _apply_dynesty_boundaries(self):
         self._periodic = list()
         self._reflective = list()
         for ii, key in enumerate(self.search_parameter_keys):
-            if self.priors[key].boundary == 'periodic':
-                logger.debug("Setting periodic boundary for {}".format(key))
+            if self.priors[key].boundary == "periodic":
+                logger.debug(f"Setting periodic boundary for {key}")
                 self._periodic.append(ii)
-            elif self.priors[key].boundary == 'reflective':
-                logger.debug("Setting reflective boundary for {}".format(key))
+            elif self.priors[key].boundary == "reflective":
+                logger.debug(f"Setting reflective boundary for {key}")
                 self._reflective.append(ii)
 
         # The periodic kwargs passed into dynesty allows the parameters to
@@ -312,61 +363,26 @@ class Dynesty(NestedSampler):
         self.kwargs["reflective"] = self._reflective
 
     def nestcheck_data(self, out_file):
-        import nestcheck.data_processing
         import pickle
+
+        import nestcheck.data_processing
+
         ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
-        nestcheck_result = "{}/{}_nestcheck.pickle".format(self.outdir, self.label)
-        with open(nestcheck_result, 'wb') as file_nest:
+        nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle"
+        with open(nestcheck_result, "wb") as file_nest:
             pickle.dump(ns_run, file_nest)
 
-    def _setup_pool(self):
-        if self.kwargs["pool"] is not None:
-            logger.info("Using user defined pool.")
-            self.pool = self.kwargs["pool"]
-        elif self.kwargs["queue_size"] > 1:
-            logger.info(
-                "Setting up multiproccesing pool with {} processes.".format(
-                    self.kwargs["queue_size"]
-                )
-            )
-            import multiprocessing
-            self.pool = multiprocessing.Pool(
-                processes=self.kwargs["queue_size"],
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio
-                )
-            )
-        else:
-            _initialize_global_variables(
-                likelihood=self.likelihood,
-                priors=self.priors,
-                search_parameter_keys=self._search_parameter_keys,
-                use_ratio=self.use_ratio
-            )
-            self.pool = None
-        self.kwargs["pool"] = self.pool
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            self.kwargs["pool"] = self.pool
-            logger.info("Finished closing worker pool.")
-
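+    # The signal_wrapper decorator takes over the SIGTERM/SIGINT/SIGALRM
+    # handling previously registered in __init__, checkpointing via
+    # write_current_state_and_exit before exiting.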
+    @signal_wrapper
     def run_sampler(self):
-        import dynesty
         import dill
-        logger.info("Using dynesty version {}".format(dynesty.__version__))
+        import dynesty
+
+        logger.info(f"Using dynesty version {dynesty.__version__}")
 
         if self.kwargs.get("sample", "rwalk") == "rwalk":
             logger.info(
-                "Using the bilby-implemented rwalk sample method with ACT estimated walks")
+                "Using the bilby-implemented rwalk sample method with ACT estimated walks"
+            )
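+            # Swap in the bilby rwalk proposal (sample_rwalk_bilby below),
+            # which sets the chain length from an autocorrelation-time
+            # estimate rather than a fixed number of steps.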
             dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
             dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
             if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"):
@@ -375,12 +391,10 @@ class Dynesty(NestedSampler):
                 raise DynestySetupError("Unable to run with nact < 1")
         elif self.kwargs.get("sample") == "rwalk_dynesty":
             self._kwargs["sample"] = "rwalk"
-            logger.info(
-                "Using the dynesty-implemented rwalk sample method")
+            logger.info("Using the dynesty-implemented rwalk sample method")
         elif self.kwargs.get("sample") == "rstagger_dynesty":
             self._kwargs["sample"] = "rstagger"
-            logger.info(
-                "Using the dynesty-implemented rstagger sample method")
+            logger.info("Using the dynesty-implemented rstagger sample method")
 
         self._setup_pool()
 
@@ -388,22 +402,25 @@ class Dynesty(NestedSampler):
             self.resume = self.read_saved_state(continuing=True)
 
         if self.resume:
-            logger.info('Resume file successfully loaded.')
+            logger.info("Resume file successfully loaded.")
         else:
-            if self.kwargs['live_points'] is None:
-                self.kwargs['live_points'] = (
-                    self.get_initial_points_from_prior(self.kwargs['nlive'])
+            if self.kwargs["live_points"] is None:
+                self.kwargs["live_points"] = self.get_initial_points_from_prior(
+                    self.kwargs["nlive"]
                 )
             self.sampler = dynesty.NestedSampler(
                 loglikelihood=_log_likelihood_wrapper,
                 prior_transform=_prior_transform_wrapper,
-                ndim=self.ndim, **self.sampler_init_kwargs
+                ndim=self.ndim,
+                **self.sampler_init_kwargs,
             )
 
+        self.start_time = datetime.datetime.now()
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._update_sampling_time()
 
         self._close_pool()
 
@@ -417,8 +434,8 @@ class Dynesty(NestedSampler):
         if self.nestcheck:
             self.nestcheck_data(out)
 
-        dynesty_result = "{}/{}_dynesty.pickle".format(self.outdir, self.label)
-        with open(dynesty_result, 'wb') as file:
+        dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
+        with open(dynesty_result, "wb") as file:
             dill.dump(out, file)
 
         self._generate_result(out)
@@ -432,21 +449,23 @@ class Dynesty(NestedSampler):
     def _generate_result(self, out):
         import dynesty
         from scipy.special import logsumexp
+
         logwts = out["logwt"]
-        weights = np.exp(logwts - out['logz'][-1])
-        nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        nested_samples['weights'] = weights
-        nested_samples['log_likelihood'] = out.logl
+        weights = np.exp(logwts - out["logz"][-1])
+        nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys)
+        nested_samples["weights"] = weights
+        nested_samples["log_likelihood"] = out.logl
         self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
         self.result.nested_samples = nested_samples
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
+            unsorted_loglikelihoods=out.logl,
+            unsorted_samples=out.samples,
+            sorted_samples=self.result.samples,
+        )
         self.result.log_evidence = out.logz[-1]
         self.result.log_evidence_err = out.logzerr[-1]
         self.result.information_gain = out.information[-1]
-        self.result.num_likelihood_evaluations = getattr(self.sampler, 'ncall', 0)
+        self.result.num_likelihood_evaluations = getattr(self.sampler, "ncall", 0)
 
         logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
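+        # Kish effective sample size, (sum w)^2 / sum(w^2), evaluated in log
+        # space for numerical stability.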
         neffsamples = int(np.exp(logneff))
@@ -454,11 +473,16 @@ class Dynesty(NestedSampler):
             nlikelihood=self.result.num_likelihood_evaluations,
             neffsamples=neffsamples,
             sampling_time_s=self.sampling_time.seconds,
-            ncores=self.kwargs.get("queue_size", 1)
+            ncores=self.kwargs.get("queue_size", 1),
         )
 
+    def _update_sampling_time(self):
+        end_time = datetime.datetime.now()
+        self.sampling_time += end_time - self.start_time
+        self.start_time = end_time
+
     def _run_nested_wrapper(self, kwargs):
-        """ Wrapper function to run_nested
+        """Wrapper function to run_nested
 
         This wrapper catches exceptions related to different versions of
         dynesty accepting different arguments.
@@ -469,8 +493,7 @@ class Dynesty(NestedSampler):
             The dictionary of kwargs to pass to run_nested
 
         """
-        logger.debug("Calling run_nested with sampler_function_kwargs {}"
-                     .format(kwargs))
+        logger.debug(f"Calling run_nested with sampler_function_kwargs {kwargs}")
         try:
             self.sampler.run_nested(**kwargs)
         except TypeError:
@@ -487,9 +510,8 @@ class Dynesty(NestedSampler):
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
-        sampler_kwargs['add_live'] = True
-        self.start_time = datetime.datetime.now()
+        sampler_kwargs["maxcall"] = self.n_check_point
+        sampler_kwargs["add_live"] = True
         while True:
             self._run_nested_wrapper(sampler_kwargs)
             if self.sampler.ncall == old_ncall:
@@ -499,14 +521,16 @@ class Dynesty(NestedSampler):
             if os.path.isfile(self.resume_file):
                 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
             else:
-                last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
+                last_checkpoint_s = (
+                    datetime.datetime.now() - self.start_time
+                ).total_seconds()
             if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state()
                 self.plot_current_state()
             if self.sampler.added_live:
                 self.sampler._remove_live_points()
 
-        sampler_kwargs['add_live'] = True
+        sampler_kwargs["add_live"] = True
         self._run_nested_wrapper(sampler_kwargs)
         self.write_current_state()
         self.plot_current_state()
@@ -534,39 +558,41 @@ class Dynesty(NestedSampler):
             Whether the run is continuing or terminating, if True, the loaded
             state is mostly written back to disk.
         """
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
+
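+        # Compare the bilby/dynesty versions stored in the resume file with
+        # the current ones; resuming across versions may fail or behave
+        # unpredictably, so mismatches are flagged below.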
         versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         if os.path.isfile(self.resume_file):
-            logger.info("Reading resume file {}".format(self.resume_file))
-            with open(self.resume_file, 'rb') as file:
+            logger.info(f"Reading resume file {self.resume_file}")
+            with open(self.resume_file, "rb") as file:
                 sampler = dill.load(file)
 
                 if not hasattr(sampler, "versions"):
                     logger.warning(
-                        "The resume file {} is corrupted or the version of "
-                        "bilby has changed between runs. This resume file will "
-                        "be ignored."
-                        .format(self.resume_file)
+                        f"The resume file {self.resume_file} is corrupted or "
+                        "the version of bilby has changed between runs. This "
+                        "resume file will be ignored."
                     )
                     return False
                 version_warning = (
                     "The {code} version has changed between runs. "
                     "This may cause unpredictable behaviour and/or failure. "
                     "Old version = {old}, new version = {new}."
-
                 )
                 for code in versions:
                     if not versions[code] == sampler.versions.get(code, None):
-                        logger.warning(version_warning.format(
-                            code=code,
-                            old=sampler.versions.get(code, "None"),
-                            new=versions[code]
-                        ))
+                        logger.warning(
+                            version_warning.format(
+                                code=code,
+                                old=sampler.versions.get(code, "None"),
+                                new=versions[code],
+                            )
+                        )
                 del sampler.versions
                 self.sampler = sampler
-                if self.sampler.added_live and continuing:
+                if getattr(self.sampler, "added_live", False) and continuing:
                     self.sampler._remove_live_points()
                 self.sampler.nqueue = -1
                 self.sampler.rstate = np.random
@@ -579,27 +605,13 @@ class Dynesty(NestedSampler):
                     self.sampler.M = map
             return True
         else:
-            logger.info(
-                "Resume file {} does not exist.".format(self.resume_file))
+            logger.info(f"Resume file {self.resume_file} does not exist.")
             return False
 
     def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.kwargs["queue_size"] == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
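+        # Close any live tqdm progress bar first so the checkpoint-and-exit
+        # logging from the base class prints cleanly.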
+        if self.kwargs["print_method"] == "tqdm":
+            self.pbar.close()
+        super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame)
 
     def write_current_state(self):
         """
@@ -613,29 +625,26 @@ class Dynesty(NestedSampler):
         normal running.
         """
 
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
 
         if getattr(self, "sampler", None) is None:
             # Sampler not initialized, not able to write current state
             return
 
         check_directory_exists_and_if_not_mkdir(self.outdir)
-        end_time = datetime.datetime.now()
-        if hasattr(self, 'start_time'):
-            self.sampling_time += end_time - self.start_time
-            self.start_time = end_time
+        if hasattr(self, "start_time"):
+            self._update_sampling_time()
             self.sampler.kwargs["sampling_time"] = self.sampling_time
             self.sampler.kwargs["start_time"] = self.start_time
-        self.sampler.versions = dict(
-            bilby=bilby_version, dynesty=dynesty_version
-        )
+        self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)
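+        # The worker pool holds live process handles and cannot be pickled,
+        # so it is detached (and the parallel map reset) before dumping; a
+        # pool is reattached when the state is read back in.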
         self.sampler.pool = None
         self.sampler.M = map
         if dill.pickles(self.sampler):
             safe_file_dump(self.sampler, self.resume_file, dill)
-            logger.info("Written checkpoint file {}".format(self.resume_file))
+            logger.info(f"Written checkpoint file {self.resume_file}")
         else:
             logger.warning(
                 "Cannot write pickle resume file! "
@@ -657,86 +666,108 @@ class Dynesty(NestedSampler):
         if nsamples < 100:
             return
 
-        filename = "{}/{}_samples.dat".format(self.outdir, self.label)
-        logger.info("Writing {} current samples to {}".format(nsamples, filename))
+        filename = f"{self.outdir}/{self.label}_samples.dat"
+        logger.info(f"Writing {nsamples} current samples to {filename}")
 
         df = DataFrame(samples, columns=self.search_parameter_keys)
-        df.to_csv(filename, index=False, header=True, sep=' ')
+        df.to_csv(filename, index=False, header=True, sep=" ")
 
     def plot_current_state(self):
         import matplotlib.pyplot as plt
+
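+        # Up to four diagnostics are written at each checkpoint: trace plots
+        # in physical and unit-cube coordinates, a dynesty run plot, and the
+        # stats plot produced by dynesty_stats_plot below.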
         if self.check_point_plot:
             import dynesty.plotting as dyplot
-            labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
+
+            labels = [label.replace("_", " ") for label in self.search_parameter_keys]
             try:
-                filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_trace.png"
                 fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
                 fig.tight_layout()
                 fig.savefig(filename)
-            except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e:
+            except Exception as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty state plot at checkpoint')
+                logger.warning("Failed to create dynesty state plot at checkpoint")
             finally:
                 plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_trace_unit.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_trace_unit.png"
                 from copy import deepcopy
+
                 temp = deepcopy(self.sampler.results)
                 temp["samples"] = temp["samples_u"]
                 fig = dyplot.traceplot(temp, labels=labels)[0]
                 fig.tight_layout()
                 fig.savefig(filename)
-            except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e:
+            except Exception as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty unit state plot at checkpoint')
+                logger.warning("Failed to create dynesty unit state plot at checkpoint")
             finally:
                 plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_run.png"
                 fig, axs = dyplot.runplot(
-                    self.sampler.results, logplot=False, use_math_text=False)
+                    self.sampler.results, logplot=False, use_math_text=False
+                )
                 fig.tight_layout()
                 plt.savefig(filename)
             except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty run plot at checkpoint')
+                logger.warning("Failed to create dynesty run plot at checkpoint")
             finally:
-                plt.close('all')
+                plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_stats.png"
                 fig, axs = dynesty_stats_plot(self.sampler)
                 fig.tight_layout()
                 plt.savefig(filename)
             except (RuntimeError, ValueError) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty stats plot at checkpoint')
+                logger.warning("Failed to create dynesty stats plot at checkpoint")
             finally:
-                plt.close('all')
+                plt.close("all")
 
     def generate_trace_plots(self, dynesty_results):
         check_directory_exists_and_if_not_mkdir(self.outdir)
-        filename = '{}/{}_trace.png'.format(self.outdir, self.label)
-        logger.debug("Writing trace plot to {}".format(filename))
+        filename = f"{self.outdir}/{self.label}_trace.png"
+        logger.debug(f"Writing trace plot to {filename}")
         from dynesty import plotting as dyplot
-        fig, axes = dyplot.traceplot(dynesty_results,
-                                     labels=self.result.parameter_labels)
+
+        fig, axes = dyplot.traceplot(
+            dynesty_results, labels=self.result.parameter_labels
+        )
         fig.tight_layout()
         fig.savefig(filename)
 
     def _run_test(self):
         import dynesty
         import pandas as pd
+
         self.sampler = dynesty.NestedSampler(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.sampler_init_kwargs)
+            ndim=self.ndim,
+            **self.sampler_init_kwargs,
+        )
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxiter'] = 2
+        sampler_kwargs["maxiter"] = 2
 
         self.sampler.run_nested(**sampler_kwargs)
         N = 100
-        self.result.samples = pd.DataFrame(
-            self.priors.sample(N))[self.search_parameter_keys].values
+        self.result.samples = pd.DataFrame(self.priors.sample(N))[
+            self.search_parameter_keys
+        ].values
         self.result.nested_samples = self.result.samples
         self.result.log_likelihood_evaluations = np.ones(N)
         self.result.log_evidence = 1
@@ -745,7 +776,7 @@ class Dynesty(NestedSampler):
         return self.result
 
     def prior_transform(self, theta):
-        """ Prior transform method that is passed into the external sampler.
+        """Prior transform method that is passed into the external sampler.
         cube we map this back to [0, 1].
 
         Parameters
@@ -762,25 +793,24 @@ class Dynesty(NestedSampler):
 
 
 def sample_rwalk_bilby(args):
-    """ Modified bilby-implemented version of dynesty.sampling.sample_rwalk """
+    """Modified bilby-implemented version of dynesty.sampling.sample_rwalk"""
     from dynesty.utils import unitcheck
 
     # Unzipping.
-    (u, loglstar, axes, scale,
-     prior_transform, loglikelihood, kwargs) = args
+    (u, loglstar, axes, scale, prior_transform, loglikelihood, kwargs) = args
     rstate = np.random
 
     # Bounds
-    nonbounded = kwargs.get('nonbounded', None)
-    periodic = kwargs.get('periodic', None)
-    reflective = kwargs.get('reflective', None)
+    nonbounded = kwargs.get("nonbounded", None)
+    periodic = kwargs.get("periodic", None)
+    reflective = kwargs.get("reflective", None)
 
     # Setup.
     n = len(u)
-    walks = kwargs.get('walks', 100)  # minimum number of steps
-    maxmcmc = kwargs.get('maxmcmc', 5000)  # Maximum number of steps
-    nact = kwargs.get('nact', 5)  # Number of ACT
-    old_act = kwargs.get('old_act', walks)
+    walks = kwargs.get("walks", 100)  # minimum number of steps
+    maxmcmc = kwargs.get("maxmcmc", 5000)  # Maximum number of steps
+    nact = kwargs.get("nact", 5)  # Number of ACT
+    old_act = kwargs.get("old_act", walks)
 
     # Initialize internal variables
     accept = 0
@@ -848,19 +878,21 @@ def sample_rwalk_bilby(args):
         if accept + reject > walks:
             act = estimate_nmcmc(
                 accept_ratio=accept / (accept + reject + nfail),
-                old_act=old_act, maxmcmc=maxmcmc)
+                old_act=old_act,
+                maxmcmc=maxmcmc,
+            )
 
         # If we've taken too many likelihood evaluations then break
         if accept + reject > maxmcmc:
             warnings.warn(
-                "Hit maximum number of walks {} with accept={}, reject={}, "
-                "and nfail={} try increasing maxmcmc"
-                .format(maxmcmc, accept, reject, nfail))
+                f"Hit maximum number of walks {maxmcmc} with accept={accept},"
+                f" reject={reject}, and nfail={nfail} try increasing maxmcmc"
+            )
             break
 
     # If the act is finite, pick randomly from within the chain
-    if np.isfinite(act) and int(.5 * nact * act) < len(u_list):
-        idx = np.random.randint(int(.5 * nact * act), len(u_list))
+    if np.isfinite(act) and int(0.5 * nact * act) < len(u_list):
+        idx = np.random.randint(int(0.5 * nact * act), len(u_list))
         u = u_list[idx]
         v = v_list[idx]
         logl = logl_list[idx]
@@ -870,7 +902,7 @@ def sample_rwalk_bilby(args):
         v = prior_transform(u)
         logl = loglikelihood(v)
 
-    blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
+    blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale}
     kwargs["old_act"] = act
 
     ncall = accept + reject
@@ -878,7 +910,7 @@ def sample_rwalk_bilby(args):
 
 
 def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
-    """ Estimate autocorrelation length of chain using acceptance fraction
+    """Estimate autocorrelation length of chain using acceptance fraction
 
     Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapted from CPNest:
 
@@ -905,9 +937,8 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
     if accept_ratio == 0.0:
         Nmcmc_exact = (1 + 1 / tau) * old_act
     else:
-        Nmcmc_exact = (
-            (1. - 1. / tau) * old_act +
-            (safety / tau) * (2. / accept_ratio - 1.)
+        Nmcmc_exact = (1.0 - 1.0 / tau) * old_act + (safety / tau) * (
+            2.0 / accept_ratio - 1.0
         )
         Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
     return max(safety, int(Nmcmc_exact))
@@ -943,7 +974,7 @@ def dynesty_stats_plot(sampler):
 
     fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
     for ax, name in zip(axs, ["nc", "scale"]):
-        ax.plot(getattr(sampler, "saved_{}".format(name)), color="blue")
+        ax.plot(getattr(sampler, f"saved_{name}"), color="blue")
         ax.set_ylabel(name.title())
     lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it
     axs[-2].set_ylabel("Lifetime")
@@ -951,9 +982,17 @@ def dynesty_stats_plot(sampler):
     burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
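+    # Under ideal nested sampling, live-point lifetimes follow a geometric
+    # distribution with p = 1/nlive; the initial burn-in and the final nlive
+    # points deviate from this and are trimmed before the KS test below.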
     if len(sampler.saved_it) > burn + sampler.nlive:
         axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
-        axs[-2].plot(np.arange(burn, len(lifetimes) - nlive), lifetimes[burn: -nlive], color="blue")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
-        lifetimes = lifetimes[burn: -nlive]
+        axs[-2].plot(
+            np.arange(burn, len(lifetimes) - nlive),
+            lifetimes[burn:-nlive],
+            color="blue",
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
+        lifetimes = lifetimes[burn:-nlive]
         ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
         axs[-1].hist(
             lifetimes,
@@ -961,19 +1000,25 @@ def dynesty_stats_plot(sampler):
             histtype="step",
             density=True,
             color="blue",
-            label=f"p value = {ks_result.pvalue:.3f}"
+            label=f"p value = {ks_result.pvalue:.3f}",
         )
         axs[-1].plot(
             np.arange(1, 6 * nlive),
             geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
-            color="red"
+            color="red",
         )
         axs[-1].set_xlim(0, 6 * nlive)
         axs[-1].legend()
         axs[-1].set_yscale("log")
     else:
-        axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
+        axs[-2].plot(
+            np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey"
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
     axs[-2].set_yscale("log")
     axs[-2].set_xlabel("Iteration")
     axs[-1].set_xlabel("Lifetime")
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index d976d391ed34f0871b3dc3121f6924ee395cee41..18a36fd1371aed4ccfb47d0c412cb57c77fb803b 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,7 +1,5 @@
 import os
-import signal
 import shutil
-import sys
 from collections import namedtuple
 from distutils.version import LooseVersion
 from shutil import copyfile
@@ -9,8 +7,11 @@ from shutil import copyfile
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import MCMCSampler, SamplerError
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import MCMCSampler, SamplerError, signal_wrapper
+from .ptemcee import LikePriorEvaluator
+
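+# Shared evaluator used as the emcee log-probability function, replacing the
+# old Emcee.lnpostfn bound method (removed below); as a module-level object
+# it can be pickled by reference when the sampler is checkpointed.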
+_evaluator = LikePriorEvaluator()
 
 
 class Emcee(MCMCSampler):
@@ -45,81 +46,111 @@ class Emcee(MCMCSampler):
     """
 
     default_kwargs = dict(
-        nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None,
-        live_dangerously=False, runtime_sortingfn=None, lnprob0=None,
-        rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True,
-        mh_proposal=None)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
-                 burn_in_act=3, verbose=True, **kwargs):
+        nwalkers=500,
+        a=2,
+        args=[],
+        kwargs={},
+        postargs=None,
+        pool=None,
+        live_dangerously=False,
+        runtime_sortingfn=None,
+        lnprob0=None,
+        rstate0=None,
+        blobs0=None,
+        iterations=100,
+        thin=1,
+        storechain=True,
+        mh_proposal=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        **kwargs,
+    ):
         import emcee
-        self.emcee = emcee
 
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         super(Emcee, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir,
-            label=label, use_ratio=use_ratio, plot=plot,
-            skip_import_verification=skip_import_verification, **kwargs)
-        self.emcee = self._check_version()
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+        self._check_version()
         self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
-        self.verbose = verbose
-
-        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
-        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
+        self.verbose = kwargs.get("verbose", True)
 
     def _check_version(self):
         import emcee
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         return emcee
 
     def _translate_kwargs(self, kwargs):
-        if 'nwalkers' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nwalkers'] = kwargs.pop(equiv)
-        if 'iterations' not in kwargs:
-            if 'nsteps' in kwargs:
-                kwargs['iterations'] = kwargs.pop('nsteps')
-        if 'threads' in kwargs:
-            if kwargs['threads'] != 1:
-                logger.warning("The 'threads' argument cannot be used for "
-                               "parallelisation. This run will proceed "
-                               "without parallelisation, but consider the use "
-                               "of an appropriate Pool object passed to the "
-                               "'pool' keyword.")
-                kwargs['threads'] = 1
+                    kwargs["nwalkers"] = kwargs.pop(equiv)
+        if "iterations" not in kwargs:
+            if "nsteps" in kwargs:
+                kwargs["iterations"] = kwargs.pop("nsteps")
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin',
-                'storechain', 'mh_proposal']
+        keys = [
+            "lnprob0",
+            "rstate0",
+            "blobs0",
+            "iterations",
+            "thin",
+            "storechain",
+            "mh_proposal",
+        ]
 
         # updated function keywords for emcee > v2.2.1
-        updatekeys = {'p0': 'initial_state',
-                      'lnprob0': 'log_prob0',
-                      'storechain': 'store'}
+        updatekeys = {
+            "p0": "initial_state",
+            "lnprob0": "log_prob0",
+            "storechain": "store",
+        }
 
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
-        function_kwargs['p0'] = self.pos0
+        function_kwargs["p0"] = self.pos0
 
         if self.prerelease:
-            if function_kwargs['mh_proposal'] is not None:
-                logger.warning("The 'mh_proposal' option is no longer used "
-                               "in emcee v{}, and will be ignored.".format(
-                                   self.emcee.__version__))
-            del function_kwargs['mh_proposal']
+            if function_kwargs["mh_proposal"] is not None:
+                logger.warning(
+                    "The 'mh_proposal' option is no longer used "
+                    "in emcee > 2.2.1, and will be ignored."
+                )
+            del function_kwargs["mh_proposal"]
 
             for key in updatekeys:
                 if updatekeys[key] not in function_kwargs:
@@ -131,37 +162,30 @@ class Emcee(MCMCSampler):
 
     @property
     def sampler_init_kwargs(self):
-        init_kwargs = {key: value
-                       for key, value in self.kwargs.items()
-                       if key not in self.sampler_function_kwargs}
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
-        init_kwargs['lnpostfn'] = self.lnpostfn
-        init_kwargs['dim'] = self.ndim
+        init_kwargs["lnpostfn"] = _evaluator.call_emcee
+        init_kwargs["dim"] = self.ndim
 
         # updated init keywords for emcee > v2.2.1
-        updatekeys = {'dim': 'ndim',
-                      'lnpostfn': 'log_prob_fn'}
+        updatekeys = {"dim": "ndim", "lnpostfn": "log_prob_fn"}
 
         if self.prerelease:
             for key in updatekeys:
                 if key in init_kwargs:
                     init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
 
-            oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal']
+            oldfunckeys = ["p0", "lnprob0", "storechain", "mh_proposal"]
             for key in oldfunckeys:
                 if key in init_kwargs:
                     del init_kwargs[key]
 
         return init_kwargs
 
-    def lnpostfn(self, theta):
-        log_prior = self.log_prior(theta)
-        if np.isinf(log_prior):
-            return -np.inf, [np.nan, np.nan]
-        else:
-            log_likelihood = self.log_likelihood(theta)
-            return log_likelihood + log_prior, [log_likelihood, log_prior]
-
     @property
     def nburn(self):
         if type(self.__nburn) in [float, int]:
@@ -174,52 +198,54 @@ class Emcee(MCMCSampler):
     @nburn.setter
     def nburn(self, nburn):
         if isinstance(nburn, (float, int)):
-            if nburn > self.kwargs['iterations'] - 1:
-                raise ValueError('Number of burn-in samples must be smaller '
-                                 'than the total number of iterations')
+            if nburn > self.kwargs["iterations"] - 1:
+                raise ValueError(
+                    "Number of burn-in samples must be smaller "
+                    "than the total number of iterations"
+                )
 
         self.__nburn = nburn
 
     @property
     def nwalkers(self):
-        return self.kwargs['nwalkers']
+        return self.kwargs["nwalkers"]
 
     @property
     def nsteps(self):
-        return self.kwargs['iterations']
+        return self.kwargs["iterations"]
 
     @nsteps.setter
     def nsteps(self, nsteps):
-        self.kwargs['iterations'] = nsteps
+        self.kwargs["iterations"] = nsteps
 
     @property
     def stored_chain(self):
-        """ Read the stored zero-temperature chain data in from disk """
+        """Read the stored zero-temperature chain data in from disk"""
         return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
 
     @property
     def stored_samples(self):
-        """ Returns the samples stored on disk """
+        """Returns the samples stored on disk"""
         return self.stored_chain[self.search_parameter_keys]
 
     @property
     def stored_loglike(self):
-        """ Returns the log-likelihood stored on disk """
-        return self.stored_chain['log_l']
+        """Returns the log-likelihood stored on disk"""
+        return self.stored_chain["log_l"]
 
     @property
     def stored_logprior(self):
-        """ Returns the log-prior stored on disk """
-        return self.stored_chain['log_p']
+        """Returns the log-prior stored on disk"""
+        return self.stored_chain["log_p"]
 
     def _init_chain_file(self):
         with open(self.checkpoint_info.chain_file, "w+") as ff:
-            ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
-                '\t'.join(self.search_parameter_keys)))
+            search_keys_str = "\t".join(self.search_parameter_keys)
+            ff.write(f"walker\t{search_keys_str}\tlog_l\tlog_p\n")
 
     @property
     def checkpoint_info(self):
-        """ Defines various things related to checkpointing and storing data
+        """Defines various things related to checkpointing and storing data
 
         Returns
         =======
@@ -231,21 +257,25 @@ class Emcee(MCMCSampler):
 
         """
         out_dir = os.path.join(
-            self.outdir, '{}_{}'.format(self.__class__.__name__.lower(),
-                                        self.label))
+            self.outdir, f"{self.__class__.__name__.lower()}_{self.label}"
+        )
         check_directory_exists_and_if_not_mkdir(out_dir)
 
-        chain_file = os.path.join(out_dir, 'chain.dat')
-        sampler_file = os.path.join(out_dir, 'sampler.pickle')
-        chain_template =\
-            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+        chain_file = os.path.join(out_dir, "chain.dat")
+        sampler_file = os.path.join(out_dir, "sampler.pickle")
+        chain_template = (
+            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
+        )
 
         CheckpointInfo = namedtuple(
-            'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
+            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
+        )
 
         checkpoint_info = CheckpointInfo(
-            sampler_file=sampler_file, chain_file=chain_file,
-            chain_template=chain_template)
+            sampler_file=sampler_file,
+            chain_file=chain_file,
+            chain_template=chain_template,
+        )
 
         return checkpoint_info
 
@@ -254,43 +284,48 @@ class Emcee(MCMCSampler):
         nsteps = self._previous_iterations
         return self.sampler.chain[:, :nsteps, :]
 
-    def checkpoint(self):
-        """ Writes a pickle file of the sampler to disk using dill """
+    def write_current_state(self):
+        """Writes a pickle file of the sampler to disk using dill"""
         import dill
-        logger.info("Checkpointing sampler to file {}"
-                    .format(self.checkpoint_info.sampler_file))
-        with open(self.checkpoint_info.sampler_file, 'wb') as f:
+
+        logger.info(
+            f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
+        )
+        with open(self.checkpoint_info.sampler_file, "wb") as f:
             # Overwrites the stored sampler chain with one that is truncated
             # to only the completed steps
             self.sampler._chain = self.sampler_chain
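+            # Detach the unpicklable worker pool while dumping the sampler,
+            # then reattach it so sampling can continue.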
+            _pool = self.sampler.pool
+            self.sampler.pool = None
             dill.dump(self._sampler, f)
-
-    def checkpoint_and_exit(self, signum, frame):
-        logger.info("Received signal {}".format(signum))
-        self.checkpoint()
-        sys.exit()
+            self.sampler.pool = _pool
 
     def _initialise_sampler(self):
-        self._sampler = self.emcee.EnsembleSampler(**self.sampler_init_kwargs)
+        from emcee import EnsembleSampler
+
+        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
         self._init_chain_file()
 
     @property
     def sampler(self):
-        """ Returns the emcee sampler object
+        """Returns the emcee sampler object
 
         If, already initialized, returns the stored _sampler value. Otherwise,
         first checks if there is a pickle file from which to load. If there is
         not, then initialize the sampler and set the initial random draw
 
         """
-        if hasattr(self, '_sampler'):
+        if hasattr(self, "_sampler"):
             pass
         elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
             import dill
-            logger.info("Resuming run from checkpoint file {}"
-                        .format(self.checkpoint_info.sampler_file))
-            with open(self.checkpoint_info.sampler_file, 'rb') as f:
+
+            logger.info(
+                f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}"
+            )
+            with open(self.checkpoint_info.sampler_file, "rb") as f:
                 self._sampler = dill.load(f)
+                self._sampler.pool = self.pool
             self._set_pos0_for_resume()
         else:
             self._initialise_sampler()
@@ -299,7 +334,7 @@ class Emcee(MCMCSampler):
 
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
-        temp_chain_file = chain_file + '.temp'
+        temp_chain_file = chain_file + ".temp"
         if os.path.isfile(chain_file):
             copyfile(chain_file, temp_chain_file)
         if self.prerelease:
@@ -313,7 +348,7 @@ class Emcee(MCMCSampler):
 
     @property
     def _previous_iterations(self):
-        """ Returns the number of iterations that the sampler has saved
+        """Returns the number of iterations that the sampler has saved
 
         This is used when loading in a sampler from a pickle file to figure out
         how much of the run has already been completed
@@ -325,7 +360,8 @@ class Emcee(MCMCSampler):
 
     def _draw_pos0_from_prior(self):
         return np.array(
-            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
+            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+        )
 
     @property
     def _pos0_shape(self):
@@ -340,8 +376,7 @@ class Emcee(MCMCSampler):
                 self.pos0 = np.squeeze(self.pos0)
 
             if self.pos0.shape != self._pos0_shape:
-                raise ValueError(
-                    'Input pos0 should be of shape ndim, nwalkers')
+                raise ValueError("Input pos0 should be of shape ndim, nwalkers")
             logger.debug("Checking input pos0")
             for draw in self.pos0:
                 self.check_draw(draw)
@@ -352,38 +387,39 @@ class Emcee(MCMCSampler):
     def _set_pos0_for_resume(self):
         self.pos0 = self.sampler.chain[:, -1, :]
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         from tqdm.auto import tqdm
+
         sampler_function_kwargs = self.sampler_function_kwargs
-        iterations = sampler_function_kwargs.pop('iterations')
+        iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
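+        # When resuming, run only the iterations remaining after those
+        # recovered from the checkpoint pickle.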
 
         if self.prerelease:
-            sampler_function_kwargs['initial_state'] = self.pos0
+            sampler_function_kwargs["initial_state"] = self.pos0
         else:
-            sampler_function_kwargs['p0'] = self.pos0
+            sampler_function_kwargs["p0"] = self.pos0
 
         # main iteration loop
-        iterator = self.sampler.sample(
-            iterations=iterations, **sampler_function_kwargs
-        )
+        iterator = self.sampler.sample(iterations=iterations, **sampler_function_kwargs)
         if self.verbose:
             iterator = tqdm(iterator, total=iterations)
         for sample in iterator:
             self.write_chains_to_file(sample)
         if self.verbose:
             iterator.close()
-        self.checkpoint()
+        self.write_current_state()
 
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(
-            self.sampler.chain.reshape((-1, self.ndim)))
+        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
         self.print_nburn_logging_info()
 
         self._generate_result()
 
-        self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape(
-            (-1, self.ndim))
+        self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape(
+            (-1, self.ndim)
+        )
         self.result.walkers = self.sampler.chain
         return self.result
 
@@ -393,10 +429,11 @@ class Emcee(MCMCSampler):
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps` ({} < {}). Try increasing the "
-                "number of steps.".format(self.result.nburn, self.nsteps))
+                f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})."
+                " Try increasing the number of steps."
+            )
         blobs = np.array(self.sampler.blobs)
-        blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2))
+        blobs_trimmed = blobs[self.nburn :, :, :].reshape((-1, 2))
         log_likelihoods, log_priors = blobs_trimmed.T
         self.result.log_likelihood_evaluations = log_likelihoods
         self.result.log_prior_evaluations = log_priors
diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py
index 8d218472d13f7769c4b88cdd0155cb8f7c04dd0d..5f375fdbad8055e6a5bdaf7dd7e99caabe330f33 100644
--- a/bilby/core/sampler/fake_sampler.py
+++ b/bilby/core/sampler/fake_sampler.py
@@ -1,8 +1,7 @@
-
 import numpy as np
 
-from .base_sampler import Sampler
 from ..result import read_in_result
+from .base_sampler import Sampler
 
 
 class FakeSampler(Sampler):
@@ -17,17 +16,38 @@ class FakeSampler(Sampler):
     sample_file: str
         A string pointing to the posterior data file to be loaded.
     """
-    default_kwargs = dict(verbose=True, logl_args=None, logl_kwargs=None,
-                          print_progress=True)
-
-    def __init__(self, likelihood, priors, sample_file, outdir='outdir',
-                 label='label', use_ratio=False, plot=False,
-                 injection_parameters=None, meta_data=None, result_class=None,
-                 **kwargs):
-        super(FakeSampler, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                          use_ratio=False, plot=False, skip_import_verification=True,
-                                          injection_parameters=None, meta_data=None, result_class=None,
-                                          **kwargs)
+
+    default_kwargs = dict(
+        verbose=True, logl_args=None, logl_kwargs=None, print_progress=True
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        sample_file,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        injection_parameters=None,
+        meta_data=None,
+        result_class=None,
+        **kwargs
+    ):
+        super(FakeSampler, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=False,
+            plot=False,
+            skip_import_verification=True,
+            injection_parameters=None,
+            meta_data=None,
+            result_class=None,
+            **kwargs
+        )
         self._read_parameter_list_from_file(sample_file)
         self.result.outdir = outdir
         self.result.label = label
@@ -41,7 +61,7 @@ class FakeSampler(Sampler):
 
     def run_sampler(self):
         """Compute the likelihood for the list of parameter space points."""
-        self.sampler = 'fake_sampler'
+        self.sampler = "fake_sampler"
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
@@ -59,8 +79,12 @@ class FakeSampler(Sampler):
             likelihood_ratios.append(logl)
 
             if self.kwargs["verbose"]:
-                print(self.likelihood.parameters['log_likelihood'], likelihood_ratios[-1],
-                      self.likelihood.parameters['log_likelihood'] - likelihood_ratios[-1])
+                print(
+                    self.likelihood.parameters["log_likelihood"],
+                    likelihood_ratios[-1],
+                    self.likelihood.parameters["log_likelihood"]
+                    - likelihood_ratios[-1],
+                )
 
         self.result.log_likelihood_evaluations = np.array(likelihood_ratios)
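FakeSampler, as rewritten above, does no sampling of its own: it reads posterior points from sample_file and re-evaluates the likelihood at each of them. A hedged usage sketch follows; the model, data, and the path to an earlier result file are illustrative, and it assumes bilby.run_sampler forwards sample_file through to the sampler:

import numpy as np
import bilby

def model(x, m, c):
    return m * x + c

x = np.linspace(0, 1, 20)
y = model(x, 2.0, 1.0) + np.random.normal(0, 0.1, len(x))
likelihood = bilby.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
priors = dict(m=bilby.core.prior.Uniform(0, 5, "m"),
              c=bilby.core.prior.Uniform(0, 5, "c"))

# Re-score the posterior stored by a previous run (path is hypothetical)
result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="fake_sampler",
    sample_file="outdir/earlier_run_result.json",
    outdir="outdir",
    label="rescored",
)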
 
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index 83947fc88378c5401508eac458192141cd9f221e..1f09387cc33520a7c8408db7cd7af2a924fa85cf 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -2,8 +2,12 @@ import os
 
 import numpy as np
 
-from .emcee import Emcee
 from ..utils import logger
+from .base_sampler import signal_wrapper
+from .emcee import Emcee
+from .ptemcee import LikePriorEvaluator
+
+_evaluator = LikePriorEvaluator()
 
 
 class Kombine(Emcee):
@@ -35,21 +39,61 @@ class Kombine(Emcee):
 
     """
 
-    default_kwargs = dict(nwalkers=500, args=[], pool=None, transd=False,
-                          lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None,
-                          kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05,
-                          max_steps=None, burnin_verbose=False)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
-                 burn_in_act=3, autoburnin=False, **kwargs):
-        super(Kombine, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                      use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
-                                      pos0=pos0, nburn=nburn, burn_in_fraction=burn_in_fraction,
-                                      burn_in_act=burn_in_act, resume=resume, **kwargs)
-
-        if self.kwargs['nwalkers'] > self.kwargs['iterations']:
+    default_kwargs = dict(
+        nwalkers=500,
+        args=[],
+        pool=None,
+        transd=False,
+        lnpost0=None,
+        blob0=None,
+        iterations=500,
+        storechain=True,
+        processes=1,
+        update_interval=None,
+        kde=None,
+        kde_size=None,
+        spaces=None,
+        freeze_transd=False,
+        test_steps=16,
+        critical_pval=0.05,
+        max_steps=None,
+        burnin_verbose=False,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        autoburnin=False,
+        **kwargs,
+    ):
+        super(Kombine, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            pos0=pos0,
+            nburn=nburn,
+            burn_in_fraction=burn_in_fraction,
+            burn_in_act=burn_in_act,
+            resume=resume,
+            **kwargs,
+        )
+
+        if self.kwargs["nwalkers"] > self.kwargs["iterations"]:
             raise ValueError("Kombine Sampler requires Iterations be > nWalkers")
         self.autoburnin = autoburnin
 
@@ -57,42 +101,34 @@ class Kombine(Emcee):
         # set prerelease to False to prevent checks for newer emcee versions in parent class
         self.prerelease = False
 
-    def _translate_kwargs(self, kwargs):
-        if 'nwalkers' not in kwargs:
-            for equiv in self.nwalkers_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs['nwalkers'] = kwargs.pop(equiv)
-        if 'iterations' not in kwargs:
-            if 'nsteps' in kwargs:
-                kwargs['iterations'] = kwargs.pop('nsteps')
-        # make sure processes kwarg is 1
-        if 'processes' in kwargs:
-            if kwargs['processes'] != 1:
-                logger.warning("The 'processes' argument cannot be used for "
-                               "parallelisation. This run will proceed "
-                               "without parallelisation, but consider the use "
-                               "of an appropriate Pool object passed to the "
-                               "'pool' keyword.")
-                kwargs['processes'] = 1
-
     @property
     def sampler_function_kwargs(self):
-        keys = ['lnpost0', 'blob0', 'iterations', 'storechain', 'lnprop0', 'update_interval', 'kde',
-                'kde_size', 'spaces', 'freeze_transd']
+        keys = [
+            "lnpost0",
+            "blob0",
+            "iterations",
+            "storechain",
+            "lnprop0",
+            "update_interval",
+            "kde",
+            "kde_size",
+            "spaces",
+            "freeze_transd",
+        ]
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
-        function_kwargs['p0'] = self.pos0
+        function_kwargs["p0"] = self.pos0
         return function_kwargs
 
     @property
     def sampler_burnin_kwargs(self):
-        extra_keys = ['test_steps', 'critical_pval', 'max_steps', 'burnin_verbose']
-        removal_keys = ['iterations', 'spaces', 'freeze_transd']
+        extra_keys = ["test_steps", "critical_pval", "max_steps", "burnin_verbose"]
+        removal_keys = ["iterations", "spaces", "freeze_transd"]
         burnin_kwargs = self.sampler_function_kwargs.copy()
         for key in extra_keys:
             if key in self.kwargs:
                 burnin_kwargs[key] = self.kwargs[key]
-        if 'burnin_verbose' in burnin_kwargs.keys():
-            burnin_kwargs['verbose'] = burnin_kwargs.pop('burnin_verbose')
+        if "burnin_verbose" in burnin_kwargs.keys():
+            burnin_kwargs["verbose"] = burnin_kwargs.pop("burnin_verbose")
         for key in removal_keys:
             if key in burnin_kwargs.keys():
                 burnin_kwargs.pop(key)
@@ -100,19 +136,21 @@ class Kombine(Emcee):
 
     @property
     def sampler_init_kwargs(self):
-        init_kwargs = {key: value
-                       for key, value in self.kwargs.items()
-                       if key not in self.sampler_function_kwargs and key not in self.sampler_burnin_kwargs}
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+            and key not in self.sampler_burnin_kwargs
+        }
         init_kwargs.pop("burnin_verbose")
-        init_kwargs['lnpostfn'] = self.lnpostfn
-        init_kwargs['ndim'] = self.ndim
+        init_kwargs["lnpostfn"] = _evaluator.call_emcee
+        init_kwargs["ndim"] = self.ndim
 
-        # have to make sure pool is None so sampler will be pickleable
-        init_kwargs['pool'] = None
         return init_kwargs
 
     def _initialise_sampler(self):
         import kombine
+
         self._sampler = kombine.Sampler(**self.sampler_init_kwargs)
         self._init_chain_file()
 
@@ -129,7 +167,9 @@ class Kombine(Emcee):
     def check_resume(self):
         return self.resume and os.path.isfile(self.checkpoint_info.sampler_file)
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         if self.autoburnin:
             if self.check_resume():
                 logger.info("Resuming with autoburnin=True skips burnin process:")
@@ -138,29 +178,50 @@ class Kombine(Emcee):
                 self.sampler.burnin(**self.sampler_burnin_kwargs)
                 self.kwargs["iterations"] += self._previous_iterations
                 self.nburn = self._previous_iterations
-                logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn))
+                logger.info(
+                    f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains"
+                )
                 self._set_pos0_for_resume()
 
         from tqdm.auto import tqdm
+
         sampler_function_kwargs = self.sampler_function_kwargs
-        iterations = sampler_function_kwargs.pop('iterations')
+        iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
-        sampler_function_kwargs['p0'] = self.pos0
+        sampler_function_kwargs["p0"] = self.pos0
         for sample in tqdm(
-                self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
-                total=iterations):
+            self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
+            total=iterations,
+        ):
             self.write_chains_to_file(sample)
-        self.checkpoint()
+        self.write_current_state()
         self.result.sampler_output = np.nan
         if not self.autoburnin:
             tmp_chain = self.sampler.chain.copy()
             self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim)))
             self.print_nburn_logging_info()
+        self._close_pool()
 
         self._generate_result()
         self.result.log_evidence_err = np.nan
 
-        tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
+        tmp_chain = self.sampler.chain[self.nburn :, :, :].copy()
         self.result.samples = tmp_chain.reshape((-1, self.ndim))
-        self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim))
+        self.result.walkers = self.sampler.chain.reshape(
+            (self.nwalkers, self.nsteps, self.ndim)
+        )
         return self.result
+
+    def _setup_pool(self):
+        from kombine import SerialPool
+
+        super(Kombine, self)._setup_pool()
+        if self.pool is None:
+            self.pool = SerialPool()
+
+    def _close_pool(self):
+        from kombine import SerialPool
+
+        if isinstance(self.pool, SerialPool):
+            self.pool = None
+        super(Kombine, self)._close_pool()
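Two pickling-related patterns appear in the kombine changes above: the log-posterior callable is a module-level LikePriorEvaluator instance (module globals pickle by reference, so _evaluator.call_emcee survives a trip through a multiprocessing pool), and _setup_pool/_close_pool substitute a SerialPool when the user supplies none, so downstream code can call pool.map unconditionally. A minimal sketch of the fall-back-pool idea, with a stand-in class rather than kombine.SerialPool:

class SerialPool:
    """Stand-in with the map/close interface of kombine.SerialPool."""

    def map(self, func, iterable):
        return list(map(func, iterable))

    def close(self):
        pass

class Worker:
    def __init__(self, pool=None):
        # Fall back to a serial pool so run() never has to branch on pool
        self.pool = pool if pool is not None else SerialPool()

    def run(self, xs):
        return self.pool.map(lambda x: x * x, xs)

print(Worker().run([1, 2, 3]))  # [1, 4, 9]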
diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py
index d8bb578a4d681c8aca425cd0b4825c651dfa73e0..d0d05037031383ff9a22a08898856e06a6ddbf8d 100644
--- a/bilby/core/sampler/nessai.py
+++ b/bilby/core/sampler/nessai.py
@@ -1,9 +1,10 @@
-import numpy as np
 import os
+
+import numpy as np
 from pandas import DataFrame
 
+from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger
 from .base_sampler import NestedSampler
-from ..utils import logger, check_directory_exists_and_if_not_mkdir, load_json
 
 
 class Nessai(NestedSampler):
@@ -16,8 +17,9 @@ class Nessai(NestedSampler):
 
     Documentation: https://nessai.readthedocs.io/
     """
+
     _default_kwargs = None
-    seed_equiv_kwargs = ['sampling_seed']
+    sampling_seed_key = "seed"
 
     @property
     def default_kwargs(self):
@@ -29,6 +31,7 @@ class Nessai(NestedSampler):
         """
         if not self._default_kwargs:
             from inspect import signature
+
             from nessai.flowsampler import FlowSampler
             from nessai.nestedsampler import NestedSampler
             from nessai.proposal import AugmentedFlowProposal, FlowProposal
@@ -42,12 +45,14 @@ class Nessai(NestedSampler):
             ]
             for c in classes:
                 kwargs.update(
-                    {k: v.default for k, v in signature(c).parameters.items() if v.default is not v.empty}
+                    {
+                        k: v.default
+                        for k, v in signature(c).parameters.items()
+                        if v.default is not v.empty
+                    }
                 )
             # Defaults for bilby that will override nessai defaults
-            bilby_defaults = dict(
-                output=None,
-            )
+            bilby_defaults = dict(output=None, exit_code=self.exit_code)
             kwargs.update(bilby_defaults)
             self._default_kwargs = kwargs
         return self._default_kwargs
@@ -69,8 +74,8 @@ class Nessai(NestedSampler):
 
     def run_sampler(self):
         from nessai.flowsampler import FlowSampler
-        from nessai.model import Model as BaseModel
         from nessai.livepoint import dict_to_live_points, live_points_to_array
+        from nessai.model import Model as BaseModel
         from nessai.posterior import compute_weights
         from nessai.utils import setup_logger
 
@@ -85,6 +90,7 @@ class Nessai(NestedSampler):
                 Priors to use for sampling. Needed for the bounds and the
                 `sample` method.
             """
+
             def __init__(self, names, priors):
                 self.names = names
                 self.priors = priors
@@ -103,8 +109,10 @@ class Nessai(NestedSampler):
                 return self.log_prior(theta)
 
             def _update_bounds(self):
-                self.bounds = {key: [self.priors[key].minimum, self.priors[key].maximum]
-                               for key in self.names}
+                self.bounds = {
+                    key: [self.priors[key].minimum, self.priors[key].maximum]
+                    for key in self.names
+                }
 
             def new_point(self, N=1):
                 """Draw a point from the prior"""
@@ -117,20 +125,22 @@ class Nessai(NestedSampler):
                 return self.log_prior(x)
 
         # Setup the logger for nessai using the same settings as the bilby logger
-        setup_logger(self.outdir, label=self.label,
-                     log_level=logger.getEffectiveLevel())
+        setup_logger(
+            self.outdir, label=self.label, log_level=logger.getEffectiveLevel()
+        )
         model = Model(self.search_parameter_keys, self.priors)
-        out = None
-        while out is None:
-            try:
-                out = FlowSampler(model, **self.kwargs)
-            except TypeError as e:
-                raise TypeError("Unable to initialise nessai sampler with error: {}".format(e))
         try:
+            out = FlowSampler(model, **self.kwargs)
             out.run(save=True, plot=self.plot)
-        except SystemExit as e:
+        except TypeError as e:
+            raise TypeError(f"Unable to initialise nessai sampler with error: {e}")
+        except (SystemExit, KeyboardInterrupt) as e:
             import sys
-            logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code))
+
+            logger.info(
+                f"Caught {type(e).__name__} with args {e.args}, "
+                f"exiting with signal {self.exit_code}"
+            )
             sys.exit(self.exit_code)
 
         # Manually set likelihood evaluations because parallelisation breaks the counter
@@ -139,53 +149,58 @@ class Nessai(NestedSampler):
         self.result.samples = live_points_to_array(
             out.posterior_samples, self.search_parameter_keys
         )
-        self.result.log_likelihood_evaluations = out.posterior_samples['logL']
+        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
         self.result.nested_samples = DataFrame(out.nested_samples)
         self.result.nested_samples.rename(
-            columns=dict(logL='log_likelihood', logP='log_prior'), inplace=True)
-        _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood),
-                                         np.array(out.ns.state.nlive))
-        self.result.nested_samples['weights'] = np.exp(log_weights)
+            columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True
+        )
+        _, log_weights = compute_weights(
+            np.array(self.result.nested_samples.log_likelihood),
+            np.array(out.ns.state.nlive),
+        )
+        self.result.nested_samples["weights"] = np.exp(log_weights)
         self.result.log_evidence = out.ns.log_evidence
         self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive)
 
         return self.result
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        super()._translate_kwargs(kwargs)
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'n_pool' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "n_pool" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['n_pool'] = kwargs.pop(equiv)
-        if 'seed' not in kwargs:
-            for equiv in self.seed_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs['seed'] = kwargs.pop(equiv)
+                    kwargs["n_pool"] = kwargs.pop(equiv)
+            if "n_pool" not in kwargs:
+                kwargs["n_pool"] = self._npool
 
     def _verify_kwargs_against_default_kwargs(self):
         """
         Set the directory where the output will be written
         and check resume and checkpoint status.
         """
-        if 'config_file' in self.kwargs:
-            d = load_json(self.kwargs['config_file'], None)
+        if "config_file" in self.kwargs:
+            d = load_json(self.kwargs["config_file"], None)
             self.kwargs.update(d)
-            self.kwargs.pop('config_file')
+            self.kwargs.pop("config_file")
 
-        if not self.kwargs['plot']:
-            self.kwargs['plot'] = self.plot
+        if not self.kwargs["plot"]:
+            self.kwargs["plot"] = self.plot
 
-        if self.kwargs['n_pool'] == 1 and self.kwargs['max_threads'] == 1:
-            logger.warning('Setting pool to None (n_pool=1 & max_threads=1)')
-            self.kwargs['n_pool'] = None
+        if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1:
+            logger.warning("Setting pool to None (n_pool=1 & max_threads=1)")
+            self.kwargs["n_pool"] = None
 
-        if not self.kwargs['output']:
-            self.kwargs['output'] = os.path.join(
-                self.outdir, '{}_nessai'.format(self.label), ''
+        if not self.kwargs["output"]:
+            self.kwargs["output"] = os.path.join(
+                self.outdir, f"{self.label}_nessai", ""
             )
 
-        check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
+        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
         NestedSampler._verify_kwargs_against_default_kwargs(self)
+
+    def _setup_pool(self):
+        pass
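The _translate_kwargs rewrite above folds user-facing aliases onto nessai's native names and, new in this patch, falls back to the run-level _npool when no pool size was given. A standalone sketch of that folding; the alias tuple is illustrative, not bilby's full list:

NPOINTS_EQUIV_KWARGS = ("npoints", "nlives", "n_live_points")

def translate(kwargs, default_npool=1):
    if "nlive" not in kwargs:
        for equiv in NPOINTS_EQUIV_KWARGS:
            if equiv in kwargs:
                kwargs["nlive"] = kwargs.pop(equiv)
    # New behaviour: default the pool size from the run-level setting
    kwargs.setdefault("n_pool", default_npool)
    return kwargs

print(translate({"npoints": 1000}))  # {'nlive': 1000, 'n_pool': 1}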
diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py
index f598d8b1751b217ae9515019c5963001eb0da840..41318e9628d63e996f50113448b4db36488382f2 100644
--- a/bilby/core/sampler/nestle.py
+++ b/bilby/core/sampler/nestle.py
@@ -1,8 +1,7 @@
-
 import numpy as np
 from pandas import DataFrame
 
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, signal_wrapper
 
 
 class Nestle(NestedSampler):
@@ -25,30 +24,45 @@ class Nestle(NestedSampler):
         sampling
 
     """
-    default_kwargs = dict(verbose=True, method='multi', npoints=500,
-                          update_interval=None, npdim=None, maxiter=None,
-                          maxcall=None, dlogz=None, decline_factor=None,
-                          rstate=None, callback=None, steps=20, enlarge=1.2)
+
+    default_kwargs = dict(
+        verbose=True,
+        method="multi",
+        npoints=500,
+        update_interval=None,
+        npdim=None,
+        maxiter=None,
+        maxcall=None,
+        dlogz=None,
+        decline_factor=None,
+        rstate=None,
+        callback=None,
+        steps=20,
+        enlarge=1.2,
+    )
 
     def _translate_kwargs(self, kwargs):
-        if 'npoints' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "npoints" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['npoints'] = kwargs.pop(equiv)
-        if 'steps' not in kwargs:
+                    kwargs["npoints"] = kwargs.pop(equiv)
+        if "steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['steps'] = kwargs.pop(equiv)
+                    kwargs["steps"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
-        if self.kwargs['verbose']:
+        if self.kwargs["verbose"]:
             import nestle
-            self.kwargs['callback'] = nestle.print_progress
-            self.kwargs.pop('verbose')
+
+            self.kwargs["callback"] = nestle.print_progress
+            self.kwargs.pop("verbose")
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
+    @signal_wrapper
     def run_sampler(self):
-        """ Runs Nestle sampler with given kwargs and returns the result
+        """Runs Nestle sampler with given kwargs and returns the result
 
         Returns
         =======
@@ -56,21 +70,27 @@ class Nestle(NestedSampler):
 
         """
         import nestle
+
         out = nestle.sample(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.kwargs)
+            ndim=self.ndim,
+            **self.kwargs
+        )
         print("")
 
         self.result.sampler_output = out
         self.result.samples = nestle.resample_equal(out.samples, out.weights)
         self.result.nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        self.result.nested_samples['weights'] = out.weights
-        self.result.nested_samples['log_likelihood'] = out.logl
+            out.samples, columns=self.search_parameter_keys
+        )
+        self.result.nested_samples["weights"] = out.weights
+        self.result.nested_samples["log_likelihood"] = out.logl
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
+            unsorted_loglikelihoods=out.logl,
+            unsorted_samples=out.samples,
+            sorted_samples=self.result.samples,
+        )
         self.result.log_evidence = out.logz
         self.result.log_evidence_err = out.logzerr
         self.result.information_gain = out.h
@@ -88,14 +108,24 @@ class Nestle(NestedSampler):
 
         """
         import nestle
+
         kwargs = self.kwargs.copy()
-        kwargs['maxiter'] = 2
+        kwargs["maxiter"] = 2
         nestle.sample(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **kwargs)
+            ndim=self.ndim,
+            **kwargs
+        )
         self.result.samples = np.random.uniform(0, 1, (100, self.ndim))
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
         self.calc_likelihood_count()
         return self.result
+
+    def write_current_state(self):
+        """
+        Nestle doesn't support checkpointing, so no current state will be
+        written on interrupt.
+        """
+        pass
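run_sampler is now decorated with signal_wrapper, which elsewhere in this patch installs handlers that checkpoint through write_current_state before exiting; because nestle cannot checkpoint, the override above is a documented no-op. A hedged sketch of what such a decorator can look like (the real signal_wrapper lives in base_sampler.py and is not shown in this diff, so details may differ):

import signal
import sys
from functools import wraps

def signal_wrapper(method):
    @wraps(method)
    def wrapped(self, *args, **kwargs):
        def handler(signum, frame):
            self.write_current_state()  # a no-op for samplers like Nestle
            sys.exit(getattr(self, "exit_code", 77))

        # Checkpoint-and-exit on interrupt or termination
        for sig in (signal.SIGINT, signal.SIGTERM):
            signal.signal(sig, handler)
        return method(self, *args, **kwargs)

    return wrapped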
diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py
index 943a5c413abe7e45ff54eb4dde2c9aa8d35b7d91..e43c5d50b248ba0fb12cd8d5bca97b0fee726c45 100644
--- a/bilby/core/sampler/polychord.py
+++ b/bilby/core/sampler/polychord.py
@@ -1,7 +1,6 @@
-
 import numpy as np
 
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, signal_wrapper
 
 
 class PyPolyChord(NestedSampler):
@@ -21,32 +20,67 @@ class PyPolyChord(NestedSampler):
     To see what the keyword arguments are for, see the docstring of PyPolyChordSettings
     """
 
-    default_kwargs = dict(use_polychord_defaults=False, nlive=None, num_repeats=None,
-                          nprior=-1, do_clustering=True, feedback=1, precision_criterion=0.001,
-                          logzero=-1e30, max_ndead=-1, boost_posterior=0.0, posteriors=True,
-                          equals=True, cluster_posteriors=True, write_resume=True,
-                          write_paramnames=False, read_resume=True, write_stats=True,
-                          write_live=True, write_dead=True, write_prior=True,
-                          compression_factor=np.exp(-1), base_dir='outdir',
-                          file_root='polychord', seed=-1, grade_dims=None, grade_frac=None, nlives={})
-
+    default_kwargs = dict(
+        use_polychord_defaults=False,
+        nlive=None,
+        num_repeats=None,
+        nprior=-1,
+        do_clustering=True,
+        feedback=1,
+        precision_criterion=0.001,
+        logzero=-1e30,
+        max_ndead=-1,
+        boost_posterior=0.0,
+        posteriors=True,
+        equals=True,
+        cluster_posteriors=True,
+        write_resume=True,
+        write_paramnames=False,
+        read_resume=True,
+        write_stats=True,
+        write_live=True,
+        write_dead=True,
+        write_prior=True,
+        compression_factor=np.exp(-1),
+        base_dir="outdir",
+        file_root="polychord",
+        seed=-1,
+        grade_dims=None,
+        grade_frac=None,
+        nlives={},
+    )
+    hard_exit = True
+    sampling_seed_key = "seed"
+
+    @signal_wrapper
     def run_sampler(self):
         import pypolychord
         from pypolychord.settings import PolyChordSettings
-        if self.kwargs['use_polychord_defaults']:
-            settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim,
-                                         base_dir=self._sample_file_directory,
-                                         file_root=self.label)
+
+        if self.kwargs["use_polychord_defaults"]:
+            settings = PolyChordSettings(
+                nDims=self.ndim,
+                nDerived=self.ndim,
+                base_dir=self._sample_file_directory,
+                file_root=self.label,
+            )
         else:
             self._setup_dynamic_defaults()
             pc_kwargs = self.kwargs.copy()
-            pc_kwargs['base_dir'] = self._sample_file_directory
-            pc_kwargs['file_root'] = self.label
-            pc_kwargs.pop('use_polychord_defaults')
-            settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim, **pc_kwargs)
+            pc_kwargs["base_dir"] = self._sample_file_directory
+            pc_kwargs["file_root"] = self.label
+            pc_kwargs.pop("use_polychord_defaults")
+            settings = PolyChordSettings(
+                nDims=self.ndim, nDerived=self.ndim, **pc_kwargs
+            )
         self._verify_kwargs_against_default_kwargs()
-        out = pypolychord.run_polychord(loglikelihood=self.log_likelihood, nDims=self.ndim,
-                                        nDerived=self.ndim, settings=settings, prior=self.prior_transform)
+        out = pypolychord.run_polychord(
+            loglikelihood=self.log_likelihood,
+            nDims=self.ndim,
+            nDerived=self.ndim,
+            settings=settings,
+            prior=self.prior_transform,
+        )
         self.result.log_evidence = out.logZ
         self.result.log_evidence_err = out.logZerr
         log_likelihoods, physical_parameters = self._read_sample_file()
@@ -56,24 +90,25 @@ class PyPolyChord(NestedSampler):
         return self.result
 
     def _setup_dynamic_defaults(self):
-        """ Sets up some interdependent default argument if none are given by the user """
-        if not self.kwargs['grade_dims']:
-            self.kwargs['grade_dims'] = [self.ndim]
-        if not self.kwargs['grade_frac']:
-            self.kwargs['grade_frac'] = [1.0] * len(self.kwargs['grade_dims'])
-        if not self.kwargs['nlive']:
-            self.kwargs['nlive'] = self.ndim * 25
-        if not self.kwargs['num_repeats']:
-            self.kwargs['num_repeats'] = self.ndim * 5
+        """Sets up some interdependent default argument if none are given by the user"""
+        if not self.kwargs["grade_dims"]:
+            self.kwargs["grade_dims"] = [self.ndim]
+        if not self.kwargs["grade_frac"]:
+            self.kwargs["grade_frac"] = [1.0] * len(self.kwargs["grade_dims"])
+        if not self.kwargs["nlive"]:
+            self.kwargs["nlive"] = self.ndim * 25
+        if not self.kwargs["num_repeats"]:
+            self.kwargs["num_repeats"] = self.ndim * 5
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
+                    kwargs["nlive"] = kwargs.pop(equiv)
 
     def log_likelihood(self, theta):
-        """ Overrides the log_likelihood so that PolyChord understands it """
+        """Overrides the log_likelihood so that PolyChord understands it"""
         return super(PyPolyChord, self).log_likelihood(theta), theta
 
     def _read_sample_file(self):
@@ -87,12 +122,14 @@ class PyPolyChord(NestedSampler):
         array_like, array_like: The log_likelihoods and the associated parameters
 
         """
-        sample_file = self._sample_file_directory + '/' + self.label + '_equal_weights.txt'
+        sample_file = (
+            self._sample_file_directory + "/" + self.label + "_equal_weights.txt"
+        )
         samples = np.loadtxt(sample_file)
         log_likelihoods = -0.5 * samples[:, 1]
-        physical_parameters = samples[:, -self.ndim:]
+        physical_parameters = samples[:, -self.ndim :]
         return log_likelihoods, physical_parameters
 
     @property
     def _sample_file_directory(self):
-        return self.outdir + '/chains'
+        return self.outdir + "/chains"
diff --git a/bilby/core/sampler/proposal.py b/bilby/core/sampler/proposal.py
index 2d52616588328c8b5bdb02247d20ff5d48b71e8b..023caac5744de968d338cea30125574248b91f95 100644
--- a/bilby/core/sampler/proposal.py
+++ b/bilby/core/sampler/proposal.py
@@ -1,13 +1,12 @@
+import random
 from inspect import isclass
 
 import numpy as np
-import random
 
 from ..prior import Uniform
 
 
 class Sample(dict):
-
     def __init__(self, dictionary=None):
         if dictionary is None:
             dictionary = dict()
@@ -31,15 +30,14 @@ class Sample(dict):
 
     @classmethod
     def from_external_type(cls, external_sample, sampler_name):
-        if sampler_name == 'cpnest':
+        if sampler_name == "cpnest":
             return cls.from_cpnest_live_point(external_sample)
         return external_sample
 
 
 class JumpProposal(object):
-
     def __init__(self, priors=None):
-        """ A generic class for jump proposals
+        """A generic class for jump proposals
 
         Parameters
         ==========
@@ -56,7 +54,7 @@ class JumpProposal(object):
         self.log_j = 0.0
 
     def __call__(self, sample, **kwargs):
-        """ A generic wrapper for the jump proposal function
+        """A generic wrapper for the jump proposal function
 
         Parameters
         ==========
@@ -71,26 +69,35 @@ class JumpProposal(object):
         return self._apply_boundaries(sample)
 
     def _move_reflecting_keys(self, sample):
-        keys = [key for key in sample.keys()
-                if self.priors[key].boundary == 'reflective']
+        keys = [
+            key for key in sample.keys() if self.priors[key].boundary == "reflective"
+        ]
         for key in keys:
-            if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum:
+            if (
+                sample[key] > self.priors[key].maximum
+                or sample[key] < self.priors[key].minimum
+            ):
                 r = self.priors[key].maximum - self.priors[key].minimum
                 delta = (sample[key] - self.priors[key].minimum) % (2 * r)
                 if delta > r:
-                    sample[key] = 2 * self.priors[key].maximum - self.priors[key].minimum - delta
+                    sample[key] = (
+                        2 * self.priors[key].maximum - self.priors[key].minimum - delta
+                    )
                 elif delta < r:
                     sample[key] = self.priors[key].minimum + delta
         return sample
 
     def _move_periodic_keys(self, sample):
-        keys = [key for key in sample.keys()
-                if self.priors[key].boundary == 'periodic']
+        keys = [key for key in sample.keys() if self.priors[key].boundary == "periodic"]
         for key in keys:
-            if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum:
-                sample[key] = (self.priors[key].minimum +
-                               ((sample[key] - self.priors[key].minimum) %
-                                (self.priors[key].maximum - self.priors[key].minimum)))
+            if (
+                sample[key] > self.priors[key].maximum
+                or sample[key] < self.priors[key].minimum
+            ):
+                sample[key] = self.priors[key].minimum + (
+                    (sample[key] - self.priors[key].minimum)
+                    % (self.priors[key].maximum - self.priors[key].minimum)
+                )
         return sample
 
     def _apply_boundaries(self, sample):
@@ -100,9 +107,8 @@ class JumpProposal(object):
 
 
 class JumpProposalCycle(object):
-
     def __init__(self, proposal_functions, weights, cycle_length=100):
-        """ A generic wrapper class for proposal cycles
+        """A generic wrapper class for proposal cycles
 
         Parameters
         ==========
@@ -129,8 +135,12 @@ class JumpProposalCycle(object):
         return len(self.proposal_functions)
 
     def update_cycle(self):
-        self._cycle = np.random.choice(self.proposal_functions, size=self.cycle_length,
-                                       p=self.weights, replace=True)
+        self._cycle = np.random.choice(
+            self.proposal_functions,
+            size=self.cycle_length,
+            p=self.weights,
+            replace=True,
+        )
 
     @property
     def proposal_functions(self):
@@ -190,9 +200,13 @@ class NormJump(JumpProposal):
 
 
 class EnsembleWalk(JumpProposal):
-
-    def __init__(self, random_number_generator=random.random, n_points=3, priors=None,
-                 **random_number_generator_args):
+    def __init__(
+        self,
+        random_number_generator=random.random,
+        n_points=3,
+        priors=None,
+        **random_number_generator_args
+    ):
         """
         An ensemble walk
 
@@ -213,12 +227,16 @@ class EnsembleWalk(JumpProposal):
         self.random_number_generator_args = random_number_generator_args
 
     def __call__(self, sample, **kwargs):
-        subset = random.sample(kwargs['coordinates'], self.n_points)
+        subset = random.sample(kwargs["coordinates"], self.n_points)
         for i in range(len(subset)):
-            subset[i] = Sample.from_external_type(subset[i], kwargs.get('sampler_name', None))
+            subset[i] = Sample.from_external_type(
+                subset[i], kwargs.get("sampler_name", None)
+            )
         center_of_mass = self.get_center_of_mass(subset)
         for x in subset:
-            sample += (x - center_of_mass) * self.random_number_generator(**self.random_number_generator_args)
+            sample += (x - center_of_mass) * self.random_number_generator(
+                **self.random_number_generator_args
+            )
         return super(EnsembleWalk, self).__call__(sample)
 
     @staticmethod
@@ -227,7 +245,6 @@ class EnsembleWalk(JumpProposal):
 
 
 class EnsembleStretch(JumpProposal):
-
     def __init__(self, scale=2.0, priors=None):
         """
         Stretch move. Calculates the log Jacobian which can be used in cpnest to bias future moves.
@@ -241,8 +258,10 @@ class EnsembleStretch(JumpProposal):
         self.scale = scale
 
     def __call__(self, sample, **kwargs):
-        second_sample = random.choice(kwargs['coordinates'])
-        second_sample = Sample.from_external_type(second_sample, kwargs.get('sampler_name', None))
+        second_sample = random.choice(kwargs["coordinates"])
+        second_sample = Sample.from_external_type(
+            second_sample, kwargs.get("sampler_name", None)
+        )
         step = random.uniform(-1, 1) * np.log(self.scale)
         sample = second_sample + (sample - second_sample) * np.exp(step)
         self.log_j = len(sample) * step
@@ -250,7 +269,6 @@ class EnsembleStretch(JumpProposal):
 
 
 class DifferentialEvolution(JumpProposal):
-
     def __init__(self, sigma=1e-4, mu=1.0, priors=None):
         """
         Differential evolution step. Takes two elements from the existing coordinates and differentially evolves the
@@ -268,13 +286,12 @@ class DifferentialEvolution(JumpProposal):
         self.mu = mu
 
     def __call__(self, sample, **kwargs):
-        a, b = random.sample(kwargs['coordinates'], 2)
+        a, b = random.sample(kwargs["coordinates"], 2)
         sample = sample + (b - a) * random.gauss(self.mu, self.sigma)
         return super(DifferentialEvolution, self).__call__(sample)
 
 
 class EnsembleEigenVector(JumpProposal):
-
     def __init__(self, priors=None):
         """
         Ensemble step based on the ensemble eigenvectors.
@@ -316,7 +333,7 @@ class EnsembleEigenVector(JumpProposal):
         self.eigen_values, self.eigen_vectors = np.linalg.eigh(self.covariance)
 
     def __call__(self, sample, **kwargs):
-        self.update_eigenvectors(kwargs['coordinates'])
+        self.update_eigenvectors(kwargs["coordinates"])
         i = random.randrange(len(sample))
         jump_size = np.sqrt(np.fabs(self.eigen_values[i])) * random.gauss(0, 1)
         for j, key in enumerate(sample.keys()):
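The reflecting move above folds an out-of-range value back into [minimum, maximum] by working modulo twice the prior width r: with delta = (sample - minimum) mod 2r, values with delta < r land at minimum + delta, while values with delta > r are reflected to 2*maximum - minimum - delta. A numeric check of that fold on a unit interval:

def reflect(x, minimum=0.0, maximum=1.0):
    r = maximum - minimum
    delta = (x - minimum) % (2 * r)
    if delta > r:
        return 2 * maximum - minimum - delta
    return minimum + delta

for x in (1.3, -0.2, 2.7):
    print(x, "->", reflect(x))  # 1.3 -> 0.7, -0.2 -> 0.2, 2.7 -> 0.7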
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 6191af0ec51ccbb984ef034d317f26b544cb8e9b..2534b0369d1de8d1b75e8561f3cafb4cda26ec73 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -2,17 +2,19 @@ import copy
 import datetime
 import logging
 import os
-import signal
-import sys
 import time
 from collections import namedtuple
 
 import numpy as np
 import pandas as pd
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import SamplerError, MCMCSampler
-
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import (
+    MCMCSampler,
+    SamplerError,
+    _sampling_convenience_dump,
+    signal_wrapper,
+)
 
 ConvergenceInputs = namedtuple(
     "ConvergenceInputs",
@@ -81,7 +83,7 @@ class Ptemcee(MCMCSampler):
         the Gelman-Rubin statistic).
     min_tau: int, (1)
         A minimum tau (autocorrelation time) to accept.
-    check_point_deltaT: float, (600)
+    check_point_delta_t: float, (600)
         The period with which to checkpoint (in seconds).
     threads: int, (1)
         If threads > 1, a MultiPool object is setup and used.
@@ -163,7 +165,7 @@ class Ptemcee(MCMCSampler):
         gradient_mean_log_posterior=0.1,
         Q_tol=1.02,
         min_tau=1,
-        check_point_deltaT=600,
+        check_point_delta_t=600,
         threads=1,
         exit_code=77,
         plot=False,
@@ -173,7 +175,7 @@ class Ptemcee(MCMCSampler):
         niterations_per_check=5,
         log10beta_min=None,
         verbose=True,
-        **kwargs
+        **kwargs,
     ):
         super(Ptemcee, self).__init__(
             likelihood=likelihood,
@@ -184,25 +186,18 @@ class Ptemcee(MCMCSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
-            **kwargs
+            **kwargs,
         )
 
         self.nwalkers = self.sampler_init_kwargs["nwalkers"]
         self.ntemps = self.sampler_init_kwargs["ntemps"]
         self.max_steps = 500
 
-        # Setup up signal handling
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
         # Checkpointing inputs
         self.resume = resume
-        self.check_point_deltaT = check_point_deltaT
+        self.check_point_delta_t = check_point_delta_t
         self.check_point_plot = check_point_plot
-        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
-            self.outdir, self.label
-        )
+        self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle"
 
         # Store convergence checking inputs in a named tuple
         convergence_inputs_dict = dict(
@@ -223,7 +218,7 @@ class Ptemcee(MCMCSampler):
             niterations_per_check=niterations_per_check,
         )
         self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
-        logger.info("Using convergence inputs: {}".format(self.convergence_inputs))
+        logger.info(f"Using convergence inputs: {self.convergence_inputs}")
 
         # Check if threads was given as an equivalent arg
         if threads == 1:
@@ -239,32 +234,50 @@ class Ptemcee(MCMCSampler):
         self.pos0 = pos0
 
         self._periodic = [
-            self.priors[key].boundary == "periodic" for key in self.search_parameter_keys
+            self.priors[key].boundary == "periodic"
+            for key in self.search_parameter_keys
         ]
         self.priors.sample()
-        self._minima = np.array([
-            self.priors[key].minimum for key in self.search_parameter_keys
-        ])
-        self._range = np.array([
-            self.priors[key].maximum for key in self.search_parameter_keys
-        ]) - self._minima
+        self._minima = np.array(
+            [self.priors[key].minimum for key in self.search_parameter_keys]
+        )
+        self._range = (
+            np.array([self.priors[key].maximum for key in self.search_parameter_keys])
+            - self._minima
+        )
 
         self.log10beta_min = log10beta_min
         if self.log10beta_min is not None:
             betas = np.logspace(0, self.log10beta_min, self.ntemps)
-            logger.warning("Using betas {}".format(betas))
+            logger.warning(f"Using betas {betas}")
             self.kwargs["betas"] = betas
         self.verbose = verbose
 
+        self.iteration = 0
+        self.chain_array = self.get_zero_chain_array()
+        self.log_likelihood_array = self.get_zero_array()
+        self.log_posterior_array = self.get_zero_array()
+        self.beta_list = list()
+        self.tau_list = list()
+        self.tau_list_n = list()
+        self.Q_list = list()
+        self.time_per_check = list()
+
+        self.nburn = np.nan
+        self.thin = np.nan
+        self.tau_int = np.nan
+        self.nsamples_effective = 0
+        self.discard = 0
+
     @property
     def sampler_function_kwargs(self):
-        """ Kwargs passed to samper.sampler() """
+        """Kwargs passed to samper.sampler()"""
         keys = ["adapt", "swap_ratios"]
         return {key: self.kwargs[key] for key in keys}
 
     @property
     def sampler_init_kwargs(self):
-        """ Kwargs passed to initialize ptemcee.Sampler() """
+        """Kwargs passed to initialize ptemcee.Sampler()"""
         return {
             key: value
             for key, value in self.kwargs.items()
@@ -272,14 +285,15 @@ class Ptemcee(MCMCSampler):
         }
 
     def _translate_kwargs(self, kwargs):
-        """ Translate kwargs """
+        """Translate kwargs"""
+        kwargs = super()._translate_kwargs(kwargs)
         if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["nwalkers"] = kwargs.pop(equiv)
 
     def get_pos0_from_prior(self):
-        """ Draw the initial positions from the prior
+        """Draw the initial positions from the prior
 
         Returns
         =======
@@ -288,16 +302,15 @@ class Ptemcee(MCMCSampler):
 
         """
         logger.info("Generating pos0 samples")
-        return np.array([
+        return np.array(
             [
-                self.get_random_draw_from_prior()
-                for _ in range(self.nwalkers)
+                [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+                for _ in range(self.kwargs["ntemps"])
             ]
-            for _ in range(self.kwargs["ntemps"])
-        ])
+        )
 
     def get_pos0_from_minimize(self, minimize_list=None):
-        """ Draw the initial positions using an initial minimization step
+        """Draw the initial positions using an initial minimization step
 
         See pos0 in the class initialization for details.
 
@@ -318,12 +331,12 @@ class Ptemcee(MCMCSampler):
         else:
             pos0 = np.array(self.get_pos0_from_prior())
 
-        logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list))
+        logger.info(f"Attempting to set pos0 for {minimize_list} from minimize")
 
         likelihood_copy = copy.copy(self.likelihood)
 
         def neg_log_like(params):
-            """ Internal function to minimize """
+            """Internal function to minimize"""
             likelihood_copy.parameters.update(
                 {key: val for key, val in zip(minimize_list, params)}
             )
@@ -360,9 +373,7 @@ class Ptemcee(MCMCSampler):
         for i, key in enumerate(minimize_list):
             pos0_min = np.min(success[:, i])
             pos0_max = np.max(success[:, i])
-            logger.info(
-                "Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
-            )
+            logger.info(f"Initialize {key} walkers from {pos0_min}->{pos0_max}")
             j = self.search_parameter_keys.index(key)
             pos0[:, :, j] = np.random.uniform(
                 pos0_min,
@@ -375,9 +386,8 @@ class Ptemcee(MCMCSampler):
         if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim):
             raise ValueError(
                 "Shape of starting array should be (ntemps, nwalkers, ndim). "
-                "In this case that is ({}, {}, {}), got {}".format(
-                    self.ntemps, self.nwalkers, self.ndim, self.pos0.shape
-                )
+                f"In this case that is ({self.ntemps}, {self.nwalkers}, "
+                f"{self.ndim}), got {self.pos0.shape}"
             )
         else:
             return self.pos0
@@ -395,12 +405,13 @@ class Ptemcee(MCMCSampler):
         return self.get_pos0_from_array()
 
     def setup_sampler(self):
-        """ Either initialize the sampler or read in the resume file """
+        """Either initialize the sampler or read in the resume file"""
         import ptemcee
 
         if os.path.isfile(self.resume_file) and self.resume is True:
             import dill
-            logger.info("Resume data {} found".format(self.resume_file))
+
+            logger.info(f"Resume data {self.resume_file} found")
             with open(self.resume_file, "rb") as file:
                 data = dill.load(file)
 
@@ -422,9 +433,7 @@ class Ptemcee(MCMCSampler):
             self.sampler.pool = self.pool
             self.sampler.threads = self.threads
 
-            logger.info(
-                "Resuming from previous run with time={}".format(self.iteration)
-            )
+            logger.info(f"Resuming from previous run with time={self.iteration}")
 
         else:
             # Initialize the PTSampler
@@ -433,32 +442,29 @@ class Ptemcee(MCMCSampler):
                     dim=self.ndim,
                     logl=self.log_likelihood,
                     logp=self.log_prior,
-                    **self.sampler_init_kwargs
+                    **self.sampler_init_kwargs,
                 )
             else:
                 self.sampler = ptemcee.Sampler(
                     dim=self.ndim,
                     logl=do_nothing_function,
                     logp=do_nothing_function,
-                    pool=self.pool,
                     threads=self.threads,
-                    **self.sampler_init_kwargs
+                    **self.sampler_init_kwargs,
                 )
 
-                self.sampler._likeprior = LikePriorEvaluator(
-                    self.search_parameter_keys, use_ratio=self.use_ratio
-                )
+            self.sampler._likeprior = LikePriorEvaluator()
 
             # Initialize storing results
             self.iteration = 0
             self.chain_array = self.get_zero_chain_array()
             self.log_likelihood_array = self.get_zero_array()
             self.log_posterior_array = self.get_zero_array()
-            self.beta_list = []
-            self.tau_list = []
-            self.tau_list_n = []
-            self.Q_list = []
-            self.time_per_check = []
+            self.beta_list = list()
+            self.tau_list = list()
+            self.tau_list_n = list()
+            self.Q_list = list()
+            self.time_per_check = list()
             self.pos0 = self.get_pos0()
 
         return self.sampler
@@ -470,7 +476,7 @@ class Ptemcee(MCMCSampler):
         return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
 
     def get_pos0(self):
-        """ Master logic for setting pos0 """
+        """Master logic for setting pos0"""
         if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
             return self.get_pos0_from_prior()
         elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
@@ -482,52 +488,55 @@ class Ptemcee(MCMCSampler):
         elif isinstance(self.pos0, dict):
             return self.get_pos0_from_dict()
         else:
-            raise SamplerError("pos0={} not implemented".format(self.pos0))
+            raise SamplerError(f"pos0={self.pos0} not implemented")
 
-    def setup_pool(self):
-        """ If threads > 1, setup a MultiPool, else run in serial mode """
-        if self.threads > 1:
-            import schwimmbad
-
-            logger.info("Creating MultiPool with {} processes".format(self.threads))
-            self.pool = schwimmbad.MultiPool(
-                self.threads, initializer=init, initargs=(self.likelihood, self.priors)
-            )
-        else:
-            self.pool = None
+    def _close_pool(self):
+        if getattr(self.sampler, "pool", None) is not None:
+            self.sampler.pool = None
+        if "pool" in self.result.sampler_kwargs:
+            del self.result.sampler_kwargs["pool"]
+        super(Ptemcee, self)._close_pool()
 
+    @signal_wrapper
     def run_sampler(self):
-        self.setup_pool()
+        self._setup_pool()
         sampler = self.setup_sampler()
 
         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
         while True:
             for (pos0, log_posterior, log_likelihood) in sampler.sample(
-                    self.pos0, storechain=False,
-                    iterations=self.convergence_inputs.niterations_per_check,
-                    **self.sampler_function_kwargs):
-                pos0[:, :, self._periodic] = np.mod(
-                    pos0[:, :, self._periodic] - self._minima[self._periodic],
-                    self._range[self._periodic]
-                ) + self._minima[self._periodic]
+                self.pos0,
+                storechain=False,
+                iterations=self.convergence_inputs.niterations_per_check,
+                **self.sampler_function_kwargs,
+            ):
+                pos0[:, :, self._periodic] = (
+                    np.mod(
+                        pos0[:, :, self._periodic] - self._minima[self._periodic],
+                        self._range[self._periodic],
+                    )
+                    + self._minima[self._periodic]
+                )
 
             if self.iteration == self.chain_array.shape[1]:
-                self.chain_array = np.concatenate((
-                    self.chain_array, self.get_zero_chain_array()), axis=1)
-                self.log_likelihood_array = np.concatenate((
-                    self.log_likelihood_array, self.get_zero_array()),
-                    axis=2)
-                self.log_posterior_array = np.concatenate((
-                    self.log_posterior_array, self.get_zero_array()),
-                    axis=2)
+                self.chain_array = np.concatenate(
+                    (self.chain_array, self.get_zero_chain_array()), axis=1
+                )
+                self.log_likelihood_array = np.concatenate(
+                    (self.log_likelihood_array, self.get_zero_array()), axis=2
+                )
+                self.log_posterior_array = np.concatenate(
+                    (self.log_posterior_array, self.get_zero_array()), axis=2
+                )
 
             self.pos0 = pos0
             self.chain_array[:, self.iteration, :] = pos0[0, :, :]
             self.log_likelihood_array[:, :, self.iteration] = log_likelihood
             self.log_posterior_array[:, :, self.iteration] = log_posterior
             self.mean_log_posterior = np.mean(
-                self.log_posterior_array[:, :, :self. iteration], axis=1)
+                self.log_posterior_array[:, :, : self.iteration], axis=1
+            )
 
             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
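The reshaped block above wraps the periodic dimensions of every walker back into the prior range in one vectorized step: subtract the minimum, reduce modulo the prior width, add the minimum back. A toy check with one temperature, two walkers, and only the first dimension periodic:

import numpy as np

minima = np.array([0.0, -1.0])
widths = np.array([2 * np.pi, 2.0])
periodic = np.array([True, False])

pos0 = np.array([[[7.0, 0.5], [-1.0, 0.5]]])  # (ntemps, nwalkers, ndim)
pos0[:, :, periodic] = (
    np.mod(pos0[:, :, periodic] - minima[periodic], widths[periodic])
    + minima[periodic]
)
print(pos0[0, :, 0])  # [0.7168... 5.2831...], both inside [0, 2*pi)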
@@ -537,15 +546,13 @@ class Ptemcee(MCMCSampler):
 
             # Calculate minimum iteration step to discard
             minimum_iteration = get_minimum_stable_itertion(
-                self.mean_log_posterior,
-                frac=self.convergence_inputs.mean_logl_frac
+                self.mean_log_posterior, frac=self.convergence_inputs.mean_logl_frac
             )
-            logger.debug("Minimum iteration = {}".format(minimum_iteration))
+            logger.debug(f"Minimum iteration = {minimum_iteration}")
 
             # Calculate the maximum discard number
             discard_max = np.max(
-                [self.convergence_inputs.burn_in_fixed_discard,
-                 minimum_iteration]
+                [self.convergence_inputs.burn_in_fixed_discard, minimum_iteration]
             )
 
             if self.iteration > discard_max + self.nwalkers:
@@ -565,7 +572,7 @@ class Ptemcee(MCMCSampler):
                 self.nsamples_effective,
             ) = check_iteration(
                 self.iteration,
-                self.chain_array[:, self.discard:self.iteration, :],
+                self.chain_array[:, self.discard : self.iteration, :],
                 sampler,
                 self.convergence_inputs,
                 self.search_parameter_keys,
@@ -588,7 +595,7 @@ class Ptemcee(MCMCSampler):
             else:
                 last_checkpoint_s = np.sum(self.time_per_check)
 
-            if last_checkpoint_s > self.check_point_deltaT:
+            if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state(plot=self.check_point_plot)
 
         # Run a final checkpoint to update the plots and samples
@@ -609,9 +616,14 @@ class Ptemcee(MCMCSampler):
         self.result.discard = self.discard
 
         log_evidence, log_evidence_err = compute_evidence(
-            sampler, self.log_likelihood_array, self.outdir,
-            self.label, self.discard, self.nburn,
-            self.thin, self.iteration,
+            sampler,
+            self.log_likelihood_array,
+            self.outdir,
+            self.label,
+            self.discard,
+            self.nburn,
+            self.thin,
+            self.iteration,
         )
         self.result.log_evidence = log_evidence
         self.result.log_evidence_err = log_evidence_err
@@ -620,21 +632,10 @@ class Ptemcee(MCMCSampler):
             seconds=np.sum(self.time_per_check)
         )
 
-        if self.pool:
-            self.pool.close()
+        self._close_pool()
 
         return self.result
 
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        logger.warning("Run terminated with signal {}".format(signum))
-        if getattr(self, "pool", None) or self.threads == 1:
-            self.write_current_state(plot=False)
-        if getattr(self, "pool", None):
-            logger.info("Closing pool")
-            self.pool.close()
-        logger.info("Exit on signal {}".format(self.exit_code))
-        sys.exit(self.exit_code)
-
     def write_current_state(self, plot=True):
         check_directory_exists_and_if_not_mkdir(self.outdir)
         checkpoint(
@@ -672,7 +673,7 @@ class Ptemcee(MCMCSampler):
                     self.discard,
                 )
             except Exception as e:
-                logger.info("Walkers plot failed with exception {}".format(e))
+                logger.info(f"Walkers plot failed with exception {e}")
 
             try:
                 # Generate the tau plot diagnostic if DEBUG
@@ -687,7 +688,7 @@ class Ptemcee(MCMCSampler):
                         self.convergence_inputs.autocorr_tau,
                     )
             except Exception as e:
-                logger.info("tau plot failed with exception {}".format(e))
+                logger.info(f"tau plot failed with exception {e}")
 
             try:
                 plot_mean_log_posterior(
@@ -696,7 +697,7 @@ class Ptemcee(MCMCSampler):
                     self.label,
                 )
             except Exception as e:
-                logger.info("mean_logl plot failed with exception {}".format(e))
+                logger.info(f"mean_logl plot failed with exception {e}")
 
 
 def get_minimum_stable_itertion(mean_array, frac, nsteps_min=10):
@@ -728,7 +729,7 @@ def check_iteration(
     mean_log_posterior,
     verbose=True,
 ):
-    """ Per-iteration logic to calculate the convergence check
+    """Per-iteration logic to calculate the convergence check
 
     Parameters
     ==========
@@ -780,8 +781,17 @@ def check_iteration(
     if np.isnan(tau) or np.isinf(tau):
         if verbose:
             print_progress(
-                iteration, sampler, time_per_check, np.nan, np.nan,
-                np.nan, np.nan, np.nan, False, convergence_inputs, Q,
+                iteration,
+                sampler,
+                time_per_check,
+                np.nan,
+                np.nan,
+                np.nan,
+                np.nan,
+                np.nan,
+                False,
+                convergence_inputs,
+                Q,
             )
         return False, np.nan, np.nan, np.nan, np.nan
 
@@ -796,45 +806,47 @@ def check_iteration(
 
     # Calculate convergence boolean
     converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
-    logger.debug("Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}"
-                 .format(Q < ci.Q_tol, ci.nsamples < nsamples_effective))
+    logger.debug(
+        f"Convergence: Q<Q_tol={Q < ci.Q_tol}, "
+        f"nsamples<nsamples_effective={ci.nsamples < nsamples_effective}"
+    )
 
     GRAD_WINDOW_LENGTH = nwalkers + 1
     nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int])
     lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check])
-    check_taus = np.array(tau_list[lower_tau_index :])
+    check_taus = np.array(tau_list[lower_tau_index:])
     if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
-        gradient_tau = get_max_gradient(
-            check_taus, axis=0, window_length=11)
+        gradient_tau = get_max_gradient(check_taus, axis=0, window_length=11)
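+        # tau is only trusted once its running estimate has flattened, i.e.
+        # the maximum smoothed gradient falls below the configured tolerance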
 
         if gradient_tau < ci.gradient_tau:
             logger.debug(
-                "tau usable as {} < gradient_tau={}"
-                .format(gradient_tau, ci.gradient_tau)
+                f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}"
             )
             tau_usable = True
         else:
             logger.debug(
-                "tau not usable as {} > gradient_tau={}"
-                .format(gradient_tau, ci.gradient_tau)
+                f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}"
             )
             tau_usable = False
 
         check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:]
         gradient_mean_log_posterior = get_max_gradient(
-            check_mean_log_posterior, axis=1, window_length=GRAD_WINDOW_LENGTH,
-            smooth=True)
+            check_mean_log_posterior,
+            axis=1,
+            window_length=GRAD_WINDOW_LENGTH,
+            smooth=True,
+        )
 
         if gradient_mean_log_posterior < ci.gradient_mean_log_posterior:
             logger.debug(
-                "tau usable as {} < gradient_mean_log_posterior={}"
-                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+                f"tau usable as {gradient_mean_log_posterior} < "
+                f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}"
             )
             tau_usable *= True
         else:
             logger.debug(
-                "tau not usable as {} > gradient_mean_log_posterior={}"
-                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+                f"tau not usable as {gradient_mean_log_posterior} > "
+                f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}"
             )
             tau_usable = False
 
@@ -864,7 +876,7 @@ def check_iteration(
             gradient_mean_log_posterior,
             tau_usable,
             convergence_inputs,
-            Q
+            Q,
         )
     stop = converged and tau_usable
     return stop, nburn, thin, tau_int, nsamples_effective
@@ -872,13 +884,14 @@ def check_iteration(
 
 def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False):
     from scipy.signal import savgol_filter
+
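+    # optionally smooth x, then return the maximum of the Savitzky-Golay
+    # first derivative (deriv=1) along the requested axis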
     if smooth:
-        x = savgol_filter(
-            x, axis=axis, window_length=window_length, polyorder=3
+        x = savgol_filter(x, axis=axis, window_length=window_length, polyorder=3)
+    return np.max(
+        savgol_filter(
+            x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1
         )
-    return np.max(savgol_filter(
-        x, axis=axis, window_length=window_length, polyorder=polyorder,
-        deriv=1))
+    )
 
 
 def get_Q_convergence(samples):
@@ -887,7 +900,7 @@ def get_Q_convergence(samples):
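+        # Gelman-Rubin diagnostic: W is the mean within-walker variance and B
+        # the between-walker variance of the per-walker means; Q approaches 1
+        # from above as all walkers sample the same distribution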
         W = np.mean(np.var(samples, axis=1), axis=0)
         per_walker_mean = np.mean(samples, axis=1)
         mean = np.mean(per_walker_mean, axis=0)
-        B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
+        B = nsteps / (nwalkers - 1.0) * np.sum((per_walker_mean - mean) ** 2, axis=0)
         Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
         Q_per_dim = np.sqrt(Vhat / W)
         return np.max(Q_per_dim)
@@ -910,16 +923,18 @@ def print_progress(
 ):
     # Setup acceptance string
     acceptance = sampler.acceptance_fraction[0, :]
-    acceptance_str = "{:1.2f}-{:1.2f}".format(np.min(acceptance), np.max(acceptance))
+    acceptance_str = f"{np.min(acceptance):1.2f}-{np.max(acceptance):1.2f}"
 
     # Setup tswap acceptance string
     tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
-    tswap_acceptance_str = "{:1.2f}-{:1.2f}".format(
-        np.min(tswap_acceptance_fraction), np.max(tswap_acceptance_fraction)
-    )
+    tswap_acceptance_str = f"{np.min(tswap_acceptance_fraction):1.2f}-{np.max(tswap_acceptance_fraction):1.2f}"
 
     ave_time_per_check = np.mean(time_per_check[-3:])
-    time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
+    time_left = (
+        (convergence_inputs.nsamples - nsamples_effective)
+        * ave_time_per_check
+        / samples_per_check
+    )
     if time_left > 0:
         time_left = str(datetime.timedelta(seconds=int(time_left)))
     else:
@@ -927,46 +942,44 @@ def print_progress(
 
     sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
 
-    tau_str = "{}(+{:0.2f},+{:0.2f})".format(
-        tau_int, gradient_tau, gradient_mean_log_posterior
-    )
+    tau_str = f"{tau_int}(+{gradient_tau:0.2f},+{gradient_mean_log_posterior:0.2f})"
 
     if tau_usable:
-        tau_str = "={}".format(tau_str)
+        tau_str = f"={tau_str}"
     else:
-        tau_str = "!{}".format(tau_str)
+        tau_str = f"!{tau_str}"
 
-    Q_str = "{:0.2f}".format(Q)
+    Q_str = f"{Q:0.2f}"
 
-    evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
+    evals_per_check = (
+        sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
+    )
 
-    ncalls = "{:1.1e}".format(
-        convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps)
-    eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
+    approximate_ncalls = (
+        convergence_inputs.niterations_per_check
+        * iteration
+        * sampler.nwalkers
+        * sampler.ntemps
+    )
+    ncalls = f"{approximate_ncalls:1.1e}"
+    eval_timing = f"{1000.0 * ave_time_per_check / evals_per_check:1.2f}ms/ev"
 
     try:
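+        # emit one status line per convergence check:
+        # iteration|elapsed|nc:<ncalls>|a0:<acceptance>|swp:<tswap acceptance>|
+        # n:<effective samples><target>|t<tau diagnostics>|q:<Q>|<ms per eval>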
         print(
-            "{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}".format(
-                iteration,
-                str(sampling_time).split(".")[0],
-                ncalls,
-                acceptance_str,
-                tswap_acceptance_str,
-                nsamples_effective,
-                convergence_inputs.nsamples,
-                tau_str,
-                Q_str,
-                eval_timing,
-            ),
+            f"{iteration}|{str(sampling_time).split('.')[0]}|nc:{ncalls}|"
+            f"a0:{acceptance_str}|swp:{tswap_acceptance_str}|"
+            f"n:{nsamples_effective}<{convergence_inputs.nsamples}|t{tau_str}|"
+            f"q:{Q_str}|{eval_timing}",
             flush=True,
         )
     except OSError as e:
-        logger.debug("Failed to print iteration due to :{}".format(e))
+        logger.debug(f"Failed to print iteration due to :{e}")
 
 
 def calculate_tau_array(samples, search_parameter_keys, ci):
-    """ Compute ACT tau for 0-temperature chains """
+    """Compute ACT tau for 0-temperature chains"""
     import emcee
+
     nwalkers, nsteps, ndim = samples.shape
     tau_array = np.zeros((nwalkers, ndim)) + np.inf
     if nsteps > 1:
@@ -976,7 +989,8 @@ def calculate_tau_array(samples, search_parameter_keys, ci):
                     continue
                 try:
                     tau_array[ii, jj] = emcee.autocorr.integrated_time(
-                        samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
+                        samples[ii, :, jj], c=ci.autocorr_c, tol=0
+                    )[0]
                 except emcee.autocorr.AutocorrError:
                     tau_array[ii, jj] = np.inf
     return tau_array
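+
+# Minimal usage sketch of the underlying emcee call, for a hypothetical 1D
+# chain ``x`` of shape (nsteps,); ``tol=0`` disables emcee's chain-length
+# check:
+#
+#     tau = emcee.autocorr.integrated_time(x, c=5, tol=0)[0]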
@@ -1004,21 +1018,24 @@ def checkpoint(
     time_per_check,
 ):
     import dill
+
     logger.info("Writing checkpoint and diagnostics")
     ndim = sampler.dim
 
     # Store the samples if possible
     if nsamples_effective > 0:
-        filename = "{}/{}_samples.txt".format(outdir, label)
-        samples = np.array(chain_array)[:, discard + nburn : iteration : thin, :].reshape(
-            (-1, ndim)
-        )
+        filename = f"{outdir}/{label}_samples.txt"
+        samples = np.array(chain_array)[
+            :, discard + nburn : iteration : thin, :
+        ].reshape((-1, ndim))
         df = pd.DataFrame(samples, columns=search_parameter_keys)
         df.to_csv(filename, index=False, header=True, sep=" ")
 
     # Pickle the resume artefacts
-    sampler_copy = copy.copy(sampler)
-    del sampler_copy.pool
+    pool = sampler.pool
+    sampler.pool = None
+    sampler_copy = copy.deepcopy(sampler)
+    sampler.pool = pool
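+    # multiprocessing pools cannot be pickled, so detach the pool before the
+    # deepcopy and re-attach it to keep the live sampler usable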
 
     data = dict(
         iteration=iteration,
@@ -1040,10 +1057,10 @@ def checkpoint(
     logger.info("Finished writing checkpoint")
 
 
-def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
-                 discard=0):
-    """ Method to plot the trace of the walkers in an ensemble MCMC plot """
+def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard=0):
+    """Method to plot the trace of the walkers in an ensemble MCMC plot"""
     import matplotlib.pyplot as plt
+
     nwalkers, nsteps, ndim = walkers.shape
     if np.isnan(nburn):
         nburn = nsteps
@@ -1051,51 +1068,65 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
         thin = 1
     idxs = np.arange(nsteps)
     fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
-    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.1,)
+    scatter_kwargs = dict(
+        lw=0,
+        marker="o",
+        markersize=1,
+        alpha=0.1,
+    )
 
     # Plot the fixed burn-in
     if discard > 0:
         for i, (ax, axh) in enumerate(axes):
             ax.plot(
-                idxs[: discard],
-                walkers[:, : discard, i].T,
+                idxs[:discard],
+                walkers[:, :discard, i].T,
                 color="gray",
-                **scatter_kwargs
+                **scatter_kwargs,
             )
 
     # Plot the burn-in
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[discard: discard + nburn + 1],
-            walkers[:, discard: discard + nburn + 1, i].T,
+            idxs[discard : discard + nburn + 1],
+            walkers[:, discard : discard + nburn + 1, i].T,
             color="C1",
-            **scatter_kwargs
+            **scatter_kwargs,
         )
 
     # Plot the thinned posterior samples
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[discard + nburn::thin],
-            walkers[:, discard + nburn::thin, i].T,
+            idxs[discard + nburn :: thin],
+            walkers[:, discard + nburn :: thin, i].T,
             color="C0",
-            **scatter_kwargs
+            **scatter_kwargs,
+        )
+        axh.hist(
+            walkers[:, discard + nburn :: thin, i].reshape((-1)), bins=50, alpha=0.8
         )
-        axh.hist(walkers[:, discard + nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
 
     for i, (ax, axh) in enumerate(axes):
         axh.set_xlabel(parameter_labels[i])
         ax.set_ylabel(parameter_labels[i])
 
     fig.tight_layout()
-    filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
+    filename = f"{outdir}/{label}_checkpoint_trace.png"
     fig.savefig(filename)
     plt.close(fig)
 
 
 def plot_tau(
-    tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau,
+    tau_list_n,
+    tau_list,
+    search_parameter_keys,
+    outdir,
+    label,
+    tau,
+    autocorr_tau,
 ):
     import matplotlib.pyplot as plt
+
     fig, ax = plt.subplots()
     for i, key in enumerate(search_parameter_keys):
         ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
@@ -1103,7 +1134,7 @@ def plot_tau(
     ax.set_ylabel(r"$\langle \tau \rangle$")
     ax.legend()
     fig.tight_layout()
-    fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
+    fig.savefig(f"{outdir}/{label}_checkpoint_tau.png")
     plt.close(fig)
 
 
@@ -1119,17 +1150,30 @@ def plot_mean_log_posterior(mean_log_posterior, outdir, label):
     fig, ax = plt.subplots()
     idxs = np.arange(nsteps)
     ax.plot(idxs, mean_log_posterior.T)
-    ax.set(xlabel="Iteration", ylabel=r"$\langle\mathrm{log-posterior}\rangle$",
-           ylim=(ymin, ymax))
+    ax.set(
+        xlabel="Iteration",
+        ylabel=r"$\langle\mathrm{log-posterior}\rangle$",
+        ylim=(ymin, ymax),
+    )
     fig.tight_layout()
-    fig.savefig("{}/{}_checkpoint_meanlogposterior.png".format(outdir, label))
+    fig.savefig(f"{outdir}/{label}_checkpoint_meanlogposterior.png")
     plt.close(fig)
 
 
-def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nburn, thin,
-                     iteration, make_plots=True):
-    """ Computes the evidence using thermodynamic integration """
+def compute_evidence(
+    sampler,
+    log_likelihood_array,
+    outdir,
+    label,
+    discard,
+    nburn,
+    thin,
+    iteration,
+    make_plots=True,
+):
+    """Computes the evidence using thermodynamic integration"""
     import matplotlib.pyplot as plt
+
     betas = sampler.betas
     # We compute the evidence without the burnin samples, but we do not thin
     lnlike = log_likelihood_array[:, :, discard + nburn : iteration]
@@ -1141,7 +1185,7 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur
     if any(np.isinf(mean_lnlikes)):
         logger.warning(
             "mean_lnlikes contains inf: recalculating without"
-            " the {} infs".format(len(betas[np.isinf(mean_lnlikes)]))
+            f" the {len(betas[np.isinf(mean_lnlikes)])} infs"
         )
         idxs = np.isinf(mean_lnlikes)
         mean_lnlikes = mean_lnlikes[~idxs]
@@ -1165,33 +1209,23 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur
 
         ax2.semilogx(min_betas, evidence, "-o")
         ax2.set_ylabel(
-            r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$",
+            r"$\int_{\beta_{min}}^{\beta=1}"
+            + r"\langle \log(\mathcal{L})\rangle d\beta$",
             size=16,
         )
         ax2.set_xlabel(r"$\beta_{min}$")
         plt.tight_layout()
-        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
+        fig.savefig(f"{outdir}/{label}_beta_lnl.png")
         plt.close(fig)
 
     return lnZ, lnZerr
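+
+# Thermodynamic integration in brief: given the mean log-likelihood
+# <ln L>_beta at each inverse temperature beta, the log-evidence is
+#     ln Z = \int_0^1 <ln L>_beta d(beta),
+# approximated numerically over the sampled temperature ladder.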
 
 
 def do_nothing_function():
-    """ This is a do-nothing function, we overwrite the likelihood and prior elsewhere """
+    """This is a do-nothing function, we overwrite the likelihood and prior elsewhere"""
     pass
 
 
-likelihood = None
-priors = None
-
-
-def init(likelihood_in, priors_in):
-    global likelihood
-    global priors
-    likelihood = likelihood_in
-    priors = priors_in
-
-
 class LikePriorEvaluator(object):
     """
     This class is copied and modified from ptemcee.LikePriorEvaluator, see
@@ -1203,38 +1237,43 @@ class LikePriorEvaluator(object):
 
     """
 
-    def __init__(self, search_parameter_keys, use_ratio=False):
-        self.search_parameter_keys = search_parameter_keys
-        self.use_ratio = use_ratio
+    def __init__(self):
         self.periodic_set = False
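+        # the likelihood, priors, and search keys are read from the
+        # module-level _sampling_convenience_dump rather than stored on the
+        # instance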
 
     def _setup_periodic(self):
+        priors = _sampling_convenience_dump.priors
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
         self._periodic = [
-            priors[key].boundary == "periodic" for key in self.search_parameter_keys
+            priors[key].boundary == "periodic" for key in search_parameter_keys
         ]
         priors.sample()
-        self._minima = np.array([
-            priors[key].minimum for key in self.search_parameter_keys
-        ])
-        self._range = np.array([
-            priors[key].maximum for key in self.search_parameter_keys
-        ]) - self._minima
+        self._minima = np.array([priors[key].minimum for key in search_parameter_keys])
+        self._range = (
+            np.array([priors[key].maximum for key in search_parameter_keys])
+            - self._minima
+        )
         self.periodic_set = True
 
     def _wrap_periodic(self, array):
         if not self.periodic_set:
             self._setup_periodic()
-        array[self._periodic] = np.mod(
-            array[self._periodic] - self._minima[self._periodic],
-            self._range[self._periodic]
-        ) + self._minima[self._periodic]
+        array[self._periodic] = (
+            np.mod(
+                array[self._periodic] - self._minima[self._periodic],
+                self._range[self._periodic],
+            )
+            + self._minima[self._periodic]
+        )
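+        # periodic dimensions are now wrapped back into
+        # [minimum, minimum + range)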
         return array
 
     def logl(self, v_array):
-        parameters = {key: v for key, v in zip(self.search_parameter_keys, v_array)}
+        priors = _sampling_convenience_dump.priors
+        likelihood = _sampling_convenience_dump.likelihood
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
+        parameters = {key: v for key, v in zip(search_parameter_keys, v_array)}
         if priors.evaluate_constraints(parameters) > 0:
             likelihood.parameters.update(parameters)
-            if self.use_ratio:
+            if _sampling_convenience_dump.use_ratio:
                 return likelihood.log_likelihood() - likelihood.noise_log_likelihood()
             else:
                 return likelihood.log_likelihood()
@@ -1242,9 +1281,15 @@ class LikePriorEvaluator(object):
             return np.nan_to_num(-np.inf)
 
     def logp(self, v_array):
-        params = {key: t for key, t in zip(self.search_parameter_keys, v_array)}
+        priors = _sampling_convenience_dump.priors
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
+        params = {key: t for key, t in zip(search_parameter_keys, v_array)}
         return priors.ln_prob(params)
 
+    def call_emcee(self, theta):
+        ll, lp = self.__call__(theta)
+        return ll + lp, [ll, lp]
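+        # emcee expects a single log-probability value; [ll, lp] is returned
+        # as per-sample "blobs" so the two terms can be recovered afterwards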
+
     def __call__(self, x):
         lp = self.logp(x)
         if np.isnan(lp):
diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py
index 49b86d7392ff4d21af698fe7b7034b25eaa7ac22..42279e018ed124cd117118e75949b60d74e3a302 100644
--- a/bilby/core/sampler/ptmcmc.py
+++ b/bilby/core/sampler/ptmcmc.py
@@ -1,11 +1,10 @@
-
 import glob
 import shutil
 
 import numpy as np
 
-from .base_sampler import MCMCSampler, SamplerNotInstalledError
 from ..utils import logger
+from .base_sampler import MCMCSampler, SamplerNotInstalledError, signal_wrapper
 
 
 class PTMCMCSampler(MCMCSampler):
@@ -42,29 +41,66 @@ class PTMCMCSampler(MCMCSampler):
 
     """
 
-    default_kwargs = {'p0': None, 'Niter': 2 * 10 ** 4 + 1, 'neff': 10 ** 4,
-                      'burn': 5 * 10 ** 3, 'verbose': True,
-                      'ladder': None, 'Tmin': 1, 'Tmax': None, 'Tskip': 100,
-                      'isave': 1000, 'thin': 1, 'covUpdate': 1000,
-                      'SCAMweight': 1, 'AMweight': 1, 'DEweight': 1,
-                      'HMCweight': 0, 'MALAweight': 0, 'NUTSweight': 0,
-                      'HMCstepsize': 0.1, 'HMCsteps': 300,
-                      'groups': None, 'custom_proposals': None,
-                      'loglargs': {}, 'loglkwargs': {}, 'logpargs': {},
-                      'logpkwargs': {}, 'logl_grad': None, 'logp_grad': None,
-                      'outDir': None}
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, burn_in_fraction=0.25, **kwargs):
-
-        super(PTMCMCSampler, self).__init__(likelihood=likelihood, priors=priors,
-                                            outdir=outdir, label=label, use_ratio=use_ratio,
-                                            plot=plot,
-                                            skip_import_verification=skip_import_verification,
-                                            **kwargs)
-
-        self.p0 = self.get_random_draw_from_prior()
+    default_kwargs = {
+        "p0": None,
+        "Niter": 2 * 10**4 + 1,
+        "neff": 10**4,
+        "burn": 5 * 10**3,
+        "verbose": True,
+        "ladder": None,
+        "Tmin": 1,
+        "Tmax": None,
+        "Tskip": 100,
+        "isave": 1000,
+        "thin": 1,
+        "covUpdate": 1000,
+        "SCAMweight": 1,
+        "AMweight": 1,
+        "DEweight": 1,
+        "HMCweight": 0,
+        "MALAweight": 0,
+        "NUTSweight": 0,
+        "HMCstepsize": 0.1,
+        "HMCsteps": 300,
+        "groups": None,
+        "custom_proposals": None,
+        "loglargs": {},
+        "loglkwargs": {},
+        "logpargs": {},
+        "logpkwargs": {},
+        "logl_grad": None,
+        "logp_grad": None,
+        "outDir": None,
+    }
+    hard_exit = True
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        **kwargs,
+    ):
+
+        super(PTMCMCSampler, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+
+        if self.kwargs["p0"] is None:
+            self.p0 = self.get_random_draw_from_prior()
+        else:
+            self.p0 = self.kwargs["p0"]
         self.likelihood = likelihood
         self.priors = priors
 
@@ -73,88 +109,103 @@ class PTMCMCSampler(MCMCSampler):
         # which forces `__name__.lower()
         external_sampler_name = self.__class__.__name__
         try:
-            self.external_sampler = __import__(external_sampler_name)
+            __import__(external_sampler_name)
         except (ImportError, SystemExit):
             raise SamplerNotInstalledError(
-                "Sampler {} is not installed on this system".format(external_sampler_name))
+                f"Sampler {external_sampler_name} is not installed on this system"
+            )
 
     def _translate_kwargs(self, kwargs):
-        if 'Niter' not in kwargs:
+        kwargs = super()._translate_kwargs(kwargs)
+        if "Niter" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['Niter'] = kwargs.pop(equiv)
-        if 'burn' not in kwargs:
+                    kwargs["Niter"] = kwargs.pop(equiv)
+        if "burn" not in kwargs:
             for equiv in self.nburn_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['burn'] = kwargs.pop(equiv)
+                    kwargs["burn"] = kwargs.pop(equiv)
 
     @property
     def custom_proposals(self):
-        return self.kwargs['custom_proposals']
+        return self.kwargs["custom_proposals"]
 
     @property
     def sampler_init_kwargs(self):
-        keys = ['groups',
-                'loglargs',
-                'logp_grad',
-                'logpkwargs',
-                'loglkwargs',
-                'logl_grad',
-                'logpargs',
-                'outDir',
-                'verbose']
+        keys = [
+            "groups",
+            "loglargs",
+            "logp_grad",
+            "logpkwargs",
+            "loglkwargs",
+            "logl_grad",
+            "logpargs",
+            "outDir",
+            "verbose",
+        ]
         init_kwargs = {key: self.kwargs[key] for key in keys}
-        if init_kwargs['outDir'] is None:
-            init_kwargs['outDir'] = '{}/ptmcmc_temp_{}/'.format(self.outdir, self.label)
+        if init_kwargs["outDir"] is None:
+            init_kwargs["outDir"] = f"{self.outdir}/ptmcmc_temp_{self.label}/"
         return init_kwargs
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['Niter',
-                'neff',
-                'Tmin',
-                'HMCweight',
-                'covUpdate',
-                'SCAMweight',
-                'ladder',
-                'burn',
-                'NUTSweight',
-                'AMweight',
-                'MALAweight',
-                'thin',
-                'HMCstepsize',
-                'isave',
-                'Tskip',
-                'HMCsteps',
-                'Tmax',
-                'DEweight']
+        keys = [
+            "Niter",
+            "neff",
+            "Tmin",
+            "HMCweight",
+            "covUpdate",
+            "SCAMweight",
+            "ladder",
+            "burn",
+            "NUTSweight",
+            "AMweight",
+            "MALAweight",
+            "thin",
+            "HMCstepsize",
+            "isave",
+            "Tskip",
+            "HMCsteps",
+            "Tmax",
+            "DEweight",
+        ]
         sampler_kwargs = {key: self.kwargs[key] for key in keys}
         return sampler_kwargs
 
     @staticmethod
     def _import_external_sampler():
         from PTMCMCSampler import PTMCMCSampler
+
         return PTMCMCSampler
 
+    @signal_wrapper
     def run_sampler(self):
         PTMCMCSampler = self._import_external_sampler()
-        sampler = PTMCMCSampler.PTSampler(ndim=self.ndim, logp=self.log_prior,
-                                          logl=self.log_likelihood, cov=np.eye(self.ndim),
-                                          **self.sampler_init_kwargs)
+        sampler = PTMCMCSampler.PTSampler(
+            ndim=self.ndim,
+            logp=self.log_prior,
+            logl=self.log_likelihood,
+            cov=np.eye(self.ndim),
+            **self.sampler_init_kwargs,
+        )
         if self.custom_proposals is not None:
             for proposal in self.custom_proposals:
-                logger.info('Adding {} to proposals with weight {}'.format(
-                    proposal, self.custom_proposals[proposal][1]))
-                sampler.addProposalToCycle(self.custom_proposals[proposal][0],
-                                           self.custom_proposals[proposal][1])
+                logger.info(
+                    f"Adding {proposal} to proposals with weight {self.custom_proposals[proposal][1]}"
+                )
+                sampler.addProposalToCycle(
+                    self.custom_proposals[proposal][0],
+                    self.custom_proposals[proposal][1],
+                )
         sampler.sample(p0=self.p0, **self.sampler_function_kwargs)
         samples, meta, loglike = self.__read_in_data()
 
         self.calc_likelihood_count()
-        self.result.nburn = self.sampler_function_kwargs['burn']
-        self.result.samples = samples[self.result.nburn:]
-        self.meta_data['sampler_meta'] = meta
-        self.result.log_likelihood_evaluations = loglike[self.result.nburn:]
+        self.result.nburn = self.sampler_function_kwargs["burn"]
+        self.result.samples = samples[self.result.nburn :]
+        self.meta_data["sampler_meta"] = meta
+        self.result.log_likelihood_evaluations = loglike[self.result.nburn :]
         self.result.sampler_output = np.nan
         self.result.walkers = np.nan
         self.result.log_evidence = np.nan
@@ -162,30 +213,34 @@ class PTMCMCSampler(MCMCSampler):
         return self.result
 
     def __read_in_data(self):
-        """ Read the data stored by PTMCMC to disk """
-        temp_outDir = self.sampler_init_kwargs['outDir']
+        """Read the data stored by PTMCMC to disk"""
+        temp_outDir = self.sampler_init_kwargs["outDir"]
         try:
-            data = np.loadtxt('{}chain_1.txt'.format(temp_outDir))
+            data = np.loadtxt(f"{temp_outDir}chain_1.txt")
         except OSError:
-            data = np.loadtxt('{}chain_1.0.txt'.format(temp_outDir))
-        jumpfiles = glob.glob('{}/*jump.txt'.format(temp_outDir))
+            data = np.loadtxt(f"{temp_outDir}chain_1.0.txt")
+        jumpfiles = glob.glob(f"{temp_outDir}/*jump.txt")
         jumps = map(np.loadtxt, jumpfiles)
         samples = data[:, :-4]
         loglike = data[:, -3]
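+        # the last four columns of the chain file are: log-posterior,
+        # log-likelihood, overall acceptance, and PT swap acceptance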
 
         jump_accept = {}
         for ct, j in enumerate(jumps):
-            label = jumpfiles[ct].split('/')[-1].split('_jump.txt')[0]
+            label = jumpfiles[ct].split("/")[-1].split("_jump.txt")[0]
             jump_accept[label] = j
-        PT_swap = {'swap_accept': data[:, -1]}
-        tot_accept = {'tot_accept': data[:, -2]}
-        log_post = {'log_post': data[:, -4]}
+        PT_swap = {"swap_accept": data[:, -1]}
+        tot_accept = {"tot_accept": data[:, -2]}
+        log_post = {"log_post": data[:, -4]}
         meta = {}
-        meta['tot_accept'] = tot_accept
-        meta['PT_swap'] = PT_swap
-        meta['proposals'] = jump_accept
-        meta['log_post'] = log_post
+        meta["tot_accept"] = tot_accept
+        meta["PT_swap"] = PT_swap
+        meta["proposals"] = jump_accept
+        meta["log_post"] = log_post
 
         shutil.rmtree(temp_outDir)
 
         return samples, meta, loglike
+
+    def write_current_state(self):
+        """TODO: implement a checkpointing method"""
+        pass
diff --git a/bilby/core/sampler/pymc.py b/bilby/core/sampler/pymc.py
new file mode 100644
index 0000000000000000000000000000000000000000..95a57dd4e94866135b50053d15f9d3e700cd8d4b
--- /dev/null
+++ b/bilby/core/sampler/pymc.py
@@ -0,0 +1,1004 @@
+from distutils.version import StrictVersion
+
+import numpy as np
+
+from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
+from ..likelihood import (
+    ExponentialLikelihood,
+    GaussianLikelihood,
+    PoissonLikelihood,
+    StudentTLikelihood,
+)
+from ..prior import Cosine, DeltaFunction, MultivariateGaussian, PowerLaw, Sine
+from ..utils import derivatives, infer_args_from_method
+from .base_sampler import MCMCSampler
+
+
+class Pymc(MCMCSampler):
+    """bilby wrapper of the PyMC sampler (https://www.pymc.io/)
+
+    All keyword arguments (i.e., the kwargs) passed to `run_sampler` will be
+    propapated to `pymc.sample` where appropriate, see documentation for that
+    class for further help. Under Other Parameters, we list commonly used
+    kwargs and the bilby, or where appropriate, PyMC defaults.
+
+    Parameters
+    ==========
+    draws: int, (500)
+        The number of sample draws from the posterior per chain.
+    chains: int, (2)
+        The number of independent MCMC chains to run.
+    cores: int, (1)
+        The number of CPU cores to use.
+    tune: int, (500)
+        The number of tuning (or burn-in) samples per chain.
+    discard_tuned_samples: bool, True
+        Set whether to automatically discard the tuning samples from the final
+        chains.
+    step: str, dict
+        Provide a step method name, or dictionary of step method names keyed to
+        particular variable names (these are case-insensitive). If passing a
+        dictionary of methods, the values keyed on particular variables can be
+        lists of methods to form compound steps. If no method is provided for
+        any particular variable then PyMC will automatically decide upon a
+        default, with the first option being the NUTS sampler. The currently
+        allowed methods are 'NUTS', 'HamiltonianMC', 'Metropolis',
+        'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and
+        'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC step
+        method function itself here as it is outside of the model context
+        manager.
+    step_kwargs: dict
+        Options for step methods other than NUTS. The dictionary is keyed on
+        lowercase step method names with values being dictionaries of keywords
+        for the given step method.
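+
+    Examples
+    ========
+    A minimal sketch (``likelihood`` and ``priors`` are assumed to be already
+    constructed; ``mu`` and ``sigma`` are hypothetical parameter names)::
+
+        result = bilby.run_sampler(
+            likelihood, priors, sampler="pymc", draws=1000, chains=2,
+            step={"mu": "Metropolis", "sigma": ["Slice", "Metropolis"]},
+        )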
+
+    """
+
+    default_kwargs = dict(
+        draws=500,
+        step=None,
+        init="auto",
+        n_init=200000,
+        initvals=None,
+        trace=None,
+        chain_idx=0,
+        chains=2,
+        cores=1,
+        tune=500,
+        progressbar=True,
+        model=None,
+        random_seed=None,
+        discard_tuned_samples=True,
+        compute_convergence_checks=True,
+        nuts_kwargs=None,
+        step_kwargs=None,
+    )
+
+    default_nuts_kwargs = dict(
+        target_accept=None,
+        max_treedepth=None,
+        step_scale=None,
+        Emax=None,
+        gamma=None,
+        k=None,
+        t0=None,
+        adapt_step_size=None,
+        early_max_treedepth=None,
+        scaling=None,
+        is_cov=None,
+        potential=None,
+    )
+
+    default_kwargs.update(default_nuts_kwargs)
+
+    sampling_seed_key = "random_seed"
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        **kwargs,
+    ):
+        # add default step kwargs
+        _, STEP_METHODS, _ = self._import_external_sampler()
+        self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
+        self.default_kwargs.update(self.default_step_kwargs)
+
+        super(Pymc, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+        self.draws = self._kwargs["draws"]
+        self.chains = self._kwargs["chains"]
+
+    @staticmethod
+    def _import_external_sampler():
+        import pymc
+        from pymc.aesaraf import floatX
+        from pymc.step_methods import STEP_METHODS
+
+        return pymc, STEP_METHODS, floatX
+
+    @staticmethod
+    def _import_aesara():
+        import aesara  # noqa
+        import aesara.tensor as tt
+        from aesara.compile.ops import as_op  # noqa
+
+        return aesara, tt, as_op
+
+    def _verify_parameters(self):
+        """
+        Change `_verify_parameters()` to just pass, i.e., don't try to
+        evaluate the likelihood for PyMC.
+        """
+        pass
+
+    def _verify_use_ratio(self):
+        """
+        Change `_verify_use_ratio()` to just pass.
+        """
+        pass
+
+    def setup_prior_mapping(self):
+        """
+        Set the mapping between predefined bilby priors and the equivalent
+        PyMC distributions.
+        """
+
+        prior_map = {}
+        self.prior_map = prior_map
+
+        # predefined PyMC distributions
+        prior_map["Gaussian"] = {
+            "pymc": "Normal",
+            "argmap": {"mu": "mu", "sigma": "sigma"},
+        }
+        prior_map["TruncatedGaussian"] = {
+            "pymc": "TruncatedNormal",
+            "argmap": {
+                "mu": "mu",
+                "sigma": "sigma",
+                "minimum": "lower",
+                "maximum": "upper",
+            },
+        }
+        prior_map["HalfGaussian"] = {"pymc": "HalfNormal", "argmap": {"sigma": "sigma"}}
+        prior_map["Uniform"] = {
+            "pymc": "Uniform",
+            "argmap": {"minimum": "lower", "maximum": "upper"},
+        }
+        prior_map["LogNormal"] = {
+            "pymc": "Lognormal",
+            "argmap": {"mu": "mu", "sigma": "sigma"},
+        }
+        prior_map["Exponential"] = {
+            "pymc": "Exponential",
+            "argmap": {"mu": "lam"},
+            "argtransform": {"mu": lambda mu: 1.0 / mu},
+        }
+        prior_map["StudentT"] = {
+            "pymc": "StudentT",
+            "argmap": {"df": "nu", "mu": "mu", "scale": "sigma"},
+        }
+        prior_map["Beta"] = {
+            "pymc": "Beta",
+            "argmap": {"alpha": "alpha", "beta": "beta"},
+        }
+        prior_map["Logistic"] = {
+            "pymc": "Logistic",
+            "argmap": {"mu": "mu", "scale": "s"},
+        }
+        prior_map["Cauchy"] = {
+            "pymc": "Cauchy",
+            "argmap": {"alpha": "alpha", "beta": "beta"},
+        }
+        prior_map["Gamma"] = {
+            "pymc": "Gamma",
+            "argmap": {"k": "alpha", "theta": "beta"},
+            "argtransform": {"theta": lambda theta: 1.0 / theta},
+        }
+        prior_map["ChiSquared"] = {"pymc": "ChiSquared", "argmap": {"nu": "nu"}}
+        prior_map["Interped"] = {
+            "pymc": "Interpolated",
+            "argmap": {"xx": "x_points", "yy": "pdf_points"},
+        }
+        prior_map["Normal"] = prior_map["Gaussian"]
+        prior_map["TruncatedNormal"] = prior_map["TruncatedGaussian"]
+        prior_map["HalfNormal"] = prior_map["HalfGaussian"]
+        prior_map["LogGaussian"] = prior_map["LogNormal"]
+        prior_map["Lorentzian"] = prior_map["Cauchy"]
+        prior_map["FromFile"] = prior_map["Interped"]
+
+        # GW specific priors
+        prior_map["UniformComovingVolume"] = prior_map["Interped"]
+
+        # internally defined mappings for bilby priors
+        prior_map["DeltaFunction"] = {"internal": self._deltafunction_prior}
+        prior_map["Sine"] = {"internal": self._sine_prior}
+        prior_map["Cosine"] = {"internal": self._cosine_prior}
+        prior_map["PowerLaw"] = {"internal": self._powerlaw_prior}
+        prior_map["LogUniform"] = {"internal": self._powerlaw_prior}
+        prior_map["MultivariateGaussian"] = {
+            "internal": self._multivariate_normal_prior
+        }
+        prior_map["MultivariateNormal"] = {"internal": self._multivariate_normal_prior}
+
+    def _deltafunction_prior(self, key, **kwargs):
+        """
+        Map the bilby delta function prior to a single value for PyMC.
+        """
+
+        # check prior is a DeltaFunction
+        if isinstance(self.priors[key], DeltaFunction):
+            return self.priors[key].peak
+        else:
+            raise ValueError(f"Prior for '{key}' is not a DeltaFunction")
+
+    def _sine_prior(self, key):
+        """
+        Map the bilby Sine prior to a PyMC style function
+        """
+
+        # check prior is a Sine
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
+        if isinstance(self.priors[key], Sine):
+
+            class PymcSine(pymc.Continuous):
+                def __init__(self, lower=0.0, upper=np.pi):
+                    if lower >= upper:
+                        raise ValueError("Lower bound is above upper bound!")
+
+                    # set the bounds, normalisation, and mean
+                    self.lower = lower = tt.as_tensor_variable(floatX(lower))
+                    self.upper = upper = tt.as_tensor_variable(floatX(upper))
+                    self.norm = tt.cos(lower) - tt.cos(upper)
+                    self.mean = (
+                        tt.sin(upper)
+                        + lower * tt.cos(lower)
+                        - tt.sin(lower)
+                        - upper * tt.cos(upper)
+                    ) / self.norm
+
+                    transform = pymc.distributions.transforms.interval(lower, upper)
+
+                    super(PymcSine, self).__init__(transform=transform)
+
+                def logp(self, value):
+                    upper = self.upper
+                    lower = self.lower
+                    return pymc.distributions.dist_math.bound(
+                        tt.log(tt.sin(value) / self.norm),
+                        lower <= value,
+                        value <= upper,
+                    )
+
+            return PymcSine(
+                key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
+            )
+        else:
+            raise ValueError(f"Prior for '{key}' is not a Sine")
+
+    def _cosine_prior(self, key):
+        """
+        Map the bilby Cosine prior to a PyMC style function
+        """
+
+        # check prior is a Cosine
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
+        if isinstance(self.priors[key], Cosine):
+
+            class PymcCosine(pymc.Continuous):
+                def __init__(self, lower=-np.pi / 2.0, upper=np.pi / 2.0):
+                    if lower >= upper:
+                        raise ValueError("Lower bound is above upper bound!")
+
+                    self.lower = lower = tt.as_tensor_variable(floatX(lower))
+                    self.upper = upper = tt.as_tensor_variable(floatX(upper))
+                    self.norm = tt.sin(upper) - tt.sin(lower)
+                    self.mean = (
+                        upper * tt.sin(upper)
+                        + tt.cos(upper)
+                        - lower * tt.sin(lower)
+                        - tt.cos(lower)
+                    ) / self.norm
+
+                    transform = pymc.distributions.transforms.interval(lower, upper)
+
+                    super(PymcCosine, self).__init__(transform=transform)
+
+                def logp(self, value):
+                    upper = self.upper
+                    lower = self.lower
+                    return pymc.distributions.dist_math.bound(
+                        tt.log(tt.cos(value) / self.norm),
+                        lower <= value,
+                        value <= upper,
+                    )
+
+            return PymcCosine(
+                key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
+            )
+        else:
+            raise ValueError(f"Prior for '{key}' is not a Cosine")
+
+    def _powerlaw_prior(self, key):
+        """
+        Map the bilby PowerLaw prior to a PyMC style function
+        """
+
+        # check prior is a PowerLaw
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
+        if isinstance(self.priors[key], PowerLaw):
+
+            # check power law is set
+            if not hasattr(self.priors[key], "alpha"):
+                raise AttributeError("No 'alpha' attribute set for PowerLaw prior")
+
+            if self.priors[key].alpha < -1.0:
+                # use Pareto distribution
+                palpha = -(1.0 + self.priors[key].alpha)
+
+                return pymc.Bound(pymc.Pareto, upper=self.priors[key].minimum)(
+                    key, alpha=palpha, m=self.priors[key].maximum
+                )
+            else:
+
+                class PymcPowerLaw(pymc.Continuous):
+                    def __init__(self, lower, upper, alpha, testval=1):
+                        falpha = alpha
+                        self.lower = lower = tt.as_tensor_variable(floatX(lower))
+                        self.upper = upper = tt.as_tensor_variable(floatX(upper))
+                        self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
+
+                        if falpha == -1:
+                            self.norm = 1.0 / (tt.log(self.upper / self.lower))
+                        else:
+                            beta = 1.0 + self.alpha
+                            self.norm = 1.0 / (
+                                beta
+                                * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta))
+                            )
+
+                        transform = pymc.distributions.transforms.interval(lower, upper)
+
+                        super(PymcPowerLaw, self).__init__(
+                            transform=transform, testval=testval
+                        )
+
+                    def logp(self, value):
+                        upper = self.upper
+                        lower = self.lower
+                        alpha = self.alpha
+
+                        return pymc.distributions.dist_math.bound(
+                            alpha * tt.log(value) + tt.log(self.norm),
+                            lower <= value,
+                            value <= upper,
+                        )
+
+                return PymcPowerLaw(
+                    key,
+                    lower=self.priors[key].minimum,
+                    upper=self.priors[key].maximum,
+                    alpha=self.priors[key].alpha,
+                )
+        else:
+            raise ValueError(f"Prior for '{key}' is not a Power Law")
+
+    def _multivariate_normal_prior(self, key):
+        """
+        Map the bilby MultivariateNormal prior to a PyMC style function.
+        """
+
+        # check prior is a MultivariateGaussian
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
+        if isinstance(self.priors[key], MultivariateGaussian):
+            # get names of multivariate Gaussian parameters
+            mvpars = self.priors[key].mvg.names
+
+            # set the prior on multiple parameters if not present yet
+            if not np.all([p in self.multivariate_normal_sets for p in mvpars]):
+                mvg = self.priors[key].mvg
+
+                # get bounds
+                lower = [bound[0] for bound in mvg.bounds.values()]
+                upper = [bound[1] for bound in mvg.bounds.values()]
+
+                # test values required for mixture
+                testvals = []
+                for bound in mvg.bounds.values():
+                    if np.isinf(bound[0]) and np.isinf(bound[1]):
+                        testvals.append(0.0)
+                    elif np.isinf(bound[0]):
+                        testvals.append(bound[1] - 1.0)
+                    elif np.isinf(bound[1]):
+                        testvals.append(bound[0] + 1.0)
+                    else:
+                        # half-way between the two bounds
+                        testvals.append(bound[0] + (bound[1] - bound[0]) / 2.0)
+
+                # if bounds are at +/- infinity, set them to 100 sigmas from
+                # the mean, as infinities cause problems for the Bound class
+                maxmu = np.max(mvg.mus, axis=0)
+                minmu = np.min(mvg.mus, axis=0)
+                maxsigma = np.max(mvg.sigmas, axis=0)
+                for i in range(len(mvpars)):
+                    if np.isinf(lower[i]):
+                        lower[i] = minmu[i] - 100.0 * maxsigma[i]
+                    if np.isinf(upper[i]):
+                        upper[i] = maxmu[i] + 100.0 * maxsigma[i]
+
+                # create a bounded MultivariateNormal distribution
+                BoundedMvN = pymc.Bound(pymc.MvNormal, lower=lower, upper=upper)
+
+                comp_dists = []  # list of component distributions, one per mode
+                for i in range(mvg.nmodes):
+                    comp_dists.append(
+                        BoundedMvN(
+                            f"comp{i}",
+                            mu=mvg.mus[i],
+                            cov=mvg.covs[i],
+                            shape=len(mvpars),
+                        ).distribution
+                    )
+
+                # create a Mixture model
+                setname = f"mixture{self.multivariate_normal_num_sets}"
+                mix = pymc.Mixture(
+                    setname,
+                    w=mvg.weights,
+                    comp_dists=comp_dists,
+                    shape=len(mvpars),
+                    testval=testvals,
+                )
+
+                for i, p in enumerate(mvpars):
+                    self.multivariate_normal_sets[p] = {}
+                    self.multivariate_normal_sets[p]["prior"] = mix[i]
+                    self.multivariate_normal_sets[p]["set"] = setname
+                    self.multivariate_normal_sets[p]["index"] = i
+
+                self.multivariate_normal_num_sets += 1
+
+            # return required parameter
+            return self.multivariate_normal_sets[key]["prior"]
+
+        else:
+            raise ValueError(f"Prior for '{key}' is not a MultivariateGaussian")
+
+    def run_sampler(self):
+        # set the step method
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
+        if "step" in self._kwargs:
+            self.step_method = self._kwargs.pop("step")
+
+            # 'step' could be a dictionary of methods for different parameters,
+            # so check for this
+            if self.step_method is None:
+                pass
+            elif isinstance(self.step_method, dict):
+                for key in self.step_method:
+                    if key not in self._search_parameter_keys:
+                        raise ValueError(
+                            f"Setting a step method for an unknown parameter '{key}'"
+                        )
+                    else:
+                        # check if using a compound step (a list of step
+                        # methods for a particular parameter)
+                        if isinstance(self.step_method[key], list):
+                            sms = self.step_method[key]
+                        else:
+                            sms = [self.step_method[key]]
+                        for sm in sms:
+                            if sm.lower() not in step_methods:
+                                raise ValueError(
+                                    f"Using invalid step method '{self.step_method[key]}'"
+                                )
+            else:
+                # check if using a compound step (a list of step
+                # methods for a particular parameter)
+                if isinstance(self.step_method, list):
+                    sms = self.step_method
+                else:
+                    sms = [self.step_method]
+
+                for i in range(len(sms)):
+                    if sms[i].lower() not in step_methods:
+                        raise ValueError(f"Using invalid step method '{sms[i]}'")
+        else:
+            self.step_method = None
+
+        # initialise the PyMC model
+        self.pymc_model = pymc.Model()
+
+        # set the prior
+        self.set_prior()
+
+        # if a custom log_likelihood function requires a `sampler` argument
+        # then use that log_likelihood function, with the assumption that it
+        # takes in a Pymc Sampler, with a pymc_model attribute, and defines
+        # the likelihood within that context manager
+        likeargs = infer_args_from_method(self.likelihood.log_likelihood)
+        if "sampler" in likeargs:
+            self.likelihood.log_likelihood(sampler=self)
+        else:
+            # set the likelihood function from predefined functions
+            self.set_likelihood()
+
+        # get the step method keyword arguments
+        step_kwargs = self.kwargs.pop("step_kwargs")
+        if step_kwargs is not None:
+            # remove all individual default step kwargs if passed together using
+            # step_kwargs keywords
+            for key in self.default_step_kwargs:
+                self.kwargs.pop(key)
+        else:
+            # remove any None default step keywords and place others in step_kwargs
+            step_kwargs = {}
+            for key in self.default_step_kwargs:
+                if self.kwargs[key] is None:
+                    self.kwargs.pop(key)
+                else:
+                    step_kwargs[key] = self.kwargs.pop(key)
+
+        nuts_kwargs = self.kwargs.pop("nuts_kwargs")
+        if nuts_kwargs is not None:
+            # remove all individual default nuts kwargs if passed together using
+            # nuts_kwargs keywords
+            for key in self.default_nuts_kwargs:
+                self.kwargs.pop(key)
+        else:
+            # remove any None default nuts keywords and place others in nuts_kwargs
+            nuts_kwargs = {}
+            for key in self.default_nuts_kwargs:
+                if self.kwargs[key] is None:
+                    self.kwargs.pop(key)
+                else:
+                    nuts_kwargs[key] = self.kwargs.pop(key)
+        methodslist = []
+
+        # set the step method
+        if isinstance(self.step_method, dict):
+            # create list of step methods (any not given will default to NUTS)
+            self.kwargs["step"] = []
+            with self.pymc_model:
+                for key in self.step_method:
+                    # check for a compound step list
+                    if isinstance(self.step_method[key], list):
+                        for sms in self.step_method[key]:
+                            curmethod = sms.lower()
+                            methodslist.append(curmethod)
+                            nuts_kwargs = self._create_nuts_kwargs(
+                                curmethod,
+                                key,
+                                nuts_kwargs,
+                                pymc,
+                                step_kwargs,
+                                step_methods,
+                            )
+                    else:
+                        curmethod = self.step_method[key].lower()
+                        methodslist.append(curmethod)
+                        nuts_kwargs = self._create_nuts_kwargs(
+                            curmethod,
+                            key,
+                            nuts_kwargs,
+                            pymc,
+                            step_kwargs,
+                            step_methods,
+                        )
+        else:
+            with self.pymc_model:
+                # check for a compound step list
+                if isinstance(self.step_method, list):
+                    compound = []
+                    for sms in self.step_method:
+                        curmethod = sms.lower()
+                        methodslist.append(curmethod)
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(
+                            curmethod, nuts_kwargs, step_kwargs
+                        )
+                        compound.append(pymc.__dict__[step_methods[curmethod]](**args))
+                        self.kwargs["step"] = compound
+                else:
+                    self.kwargs["step"] = None
+                    if self.step_method is not None:
+                        curmethod = self.step_method.lower()
+                        methodslist.append(curmethod)
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(
+                            curmethod, nuts_kwargs, step_kwargs
+                        )
+                        self.kwargs["step"] = pymc.__dict__[step_methods[curmethod]](
+                            **args
+                        )
+                    else:
+                        # re-add step_kwargs if no step methods are set
+                        if len(step_kwargs) > 0 and StrictVersion(
+                            pymc.__version__
+                        ) < StrictVersion("3.7"):
+                            self.kwargs["step_kwargs"] = step_kwargs
+
+        # check whether only NUTS step method has been assigned
+        if np.all([sm.lower() == "nuts" for sm in methodslist]):
+            # in this case we can let PyMC autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
+            self.kwargs["step"] = None
+
+            if len(nuts_kwargs) > 0 and StrictVersion(pymc.__version__) < StrictVersion(
+                "3.7"
+            ):
+                self.kwargs["nuts_kwargs"] = nuts_kwargs
+            elif len(nuts_kwargs) > 0:
+                # add NUTS kwargs to standard kwargs
+                self.kwargs.update(nuts_kwargs)
+
+        with self.pymc_model:
+            # perform the sampling
+            trace = pymc.sample(**self.kwargs)
+
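+        # flatten the ArviZ InferenceData posterior (chain and draw dimensions)
+        # into a single DataFrame of samples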
+        posterior = trace.posterior.to_dataframe().reset_index()
+        self.result.samples = posterior[self.search_parameter_keys]
+        self.result.log_likelihood_evaluations = np.sum(
+            trace.log_likelihood.likelihood.values, axis=-1
+        ).flatten()
+        self.result.sampler_output = np.nan
+        self.calculate_autocorrelation(self.result.samples)
+        self.result.log_evidence = np.nan
+        self.result.log_evidence_err = np.nan
+        self.calc_likelihood_count()
+        return self.result
+
+    def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs):
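+        # NUTS takes its arguments via _get_nuts_args (nuts_kwargs, or a "nuts"
+        # entry in step_kwargs); other methods use their own entry in step_kwargs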
+        if curmethod == "nuts":
+            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
+        else:
+            args = step_kwargs.get(curmethod, {})
+        return args, nuts_kwargs
+
+    def _create_nuts_kwargs(
+        self, curmethod, key, nuts_kwargs, pymc, step_kwargs, step_methods
+    ):
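+        # build a step-method instance acting only on this parameter's variable
+        # and append it to the compound step list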
+        if curmethod == "nuts":
+            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
+        else:
+            if step_kwargs is not None:
+                args = step_kwargs.get(curmethod, {})
+            else:
+                args = {}
+        self.kwargs["step"].append(
+            pymc.__dict__[step_methods[curmethod]](vars=[self.pymc_priors[key]], **args)
+        )
+        return nuts_kwargs
+
+    @staticmethod
+    def _get_nuts_args(nuts_kwargs, step_kwargs):
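+        # an explicit nuts_kwargs dict takes precedence; otherwise fall back to
+        # any "nuts" entry nested inside step_kwargs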
+        if nuts_kwargs is not None:
+            args = nuts_kwargs
+        elif step_kwargs is not None:
+            args = step_kwargs.pop("nuts", {})
+            # add values into nuts_kwargs
+            nuts_kwargs = args
+        else:
+            args = {}
+        return args, nuts_kwargs
+
+    def _pymc_version(self):
+        pymc, _, _ = self._import_external_sampler()
+        return pymc.__version__
+
+    def set_prior(self):
+        """
+        Set the PyMC prior distributions.
+        """
+
+        self.setup_prior_mapping()
+
+        self.pymc_priors = dict()
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+
+        # initialise a dictionary of multivariate Gaussian parameters
+        self.multivariate_normal_sets = {}
+        self.multivariate_normal_num_sets = 0
+
+        # set the parameter prior distributions (in the model context manager)
+        with self.pymc_model:
+            for key in self.priors:
+                # if the prior's ln_prob method takes a 'sampler' argument
+                # then try using that
+                lnprobargs = infer_args_from_method(self.priors[key].ln_prob)
+                if "sampler" in lnprobargs:
+                    try:
+                        self.pymc_priors[key] = self.priors[key].ln_prob(sampler=self)
+                    except RuntimeError:
+                        raise RuntimeError(f"Problem setting PyMC prior for '{key}'")
+                else:
+                    # use Prior distribution name
+                    distname = self.priors[key].__class__.__name__
+
+                    if distname in self.prior_map:
+                        # check if we have a predefined PyMC distribution
+                        if (
+                            "pymc" in self.prior_map[distname]
+                            and "argmap" in self.prior_map[distname]
+                        ):
+                            # check the required arguments for the PyMC distribution
+                            pymcdistname = self.prior_map[distname]["pymc"]
+
+                            if pymcdistname not in pymc.__dict__:
+                                raise ValueError(
+                                    f"Prior '{pymcdistname}' is not a known PyMC distribution."
+                                )
+
+                            reqargs = infer_args_from_method(
+                                pymc.__dict__[pymcdistname].dist
+                            )
+
+                            # set keyword arguments
+                            priorkwargs = {}
+                            for (targ, parg) in self.prior_map[distname][
+                                "argmap"
+                            ].items():
+                                if hasattr(self.priors[key], targ):
+                                    if parg in reqargs:
+                                        if "argtransform" in self.prior_map[distname]:
+                                            if (
+                                                targ
+                                                in self.prior_map[distname][
+                                                    "argtransform"
+                                                ]
+                                            ):
+                                                tfunc = self.prior_map[distname][
+                                                    "argtransform"
+                                                ][targ]
+                                            else:
+
+                                                def tfunc(x):
+                                                    return x
+
+                                        else:
+
+                                            def tfunc(x):
+                                                return x
+
+                                        priorkwargs[parg] = tfunc(
+                                            getattr(self.priors[key], targ)
+                                        )
+                                    else:
+                                        raise ValueError(f"Unknown argument {parg}")
+                                else:
+                                    if parg in reqargs:
+                                        priorkwargs[parg] = None
+                            self.pymc_priors[key] = pymc.__dict__[pymcdistname](
+                                key, **priorkwargs
+                            )
+                        elif "internal" in self.prior_map[distname]:
+                            self.pymc_priors[key] = self.prior_map[distname][
+                                "internal"
+                            ](key)
+                        else:
+                            raise ValueError(
+                                f"Prior '{distname}' is not a known distribution."
+                            )
+                    else:
+                        raise ValueError(
+                            f"Prior '{distname}' is not a known distribution."
+                        )
+
+    def set_likelihood(self):
+        """
+        Convert any bilby likelihoods to PyMC distributions.
+        """
+
+        # create aesara Op for the log likelihood if not using a predefined model
+        pymc, STEP_METHODS, floatX = self._import_external_sampler()
+        aesara, tt, as_op = self._import_aesara()
+
+        class LogLike(tt.Op):
+
+            itypes = [tt.dvector]
+            otypes = [tt.dscalar]
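+            # Op signature: vector of parameter values in, scalar log-likelihood out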
+
+            def __init__(self, parameters, loglike, priors):
+                self.parameters = parameters
+                self.likelihood = loglike
+                self.priors = priors
+
+                # set the fixed parameters
+                for key in self.priors.keys():
+                    if isinstance(self.priors[key], float):
+                        self.likelihood.parameters[key] = self.priors[key]
+
+                self.logpgrad = LogLikeGrad(
+                    self.parameters, self.likelihood, self.priors
+                )
+
+            def perform(self, node, inputs, outputs):
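+                # copy the proposed point into the bilby likelihood and evaluate it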
+                (theta,) = inputs
+                for i, key in enumerate(self.parameters):
+                    self.likelihood.parameters[key] = theta[i]
+
+                outputs[0][0] = np.array(self.likelihood.log_likelihood())
+
+            def grad(self, inputs, g):
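+                # chain rule: scale the log-likelihood gradient by the upstream gradient g[0]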
+                (theta,) = inputs
+                return [g[0] * self.logpgrad(theta)]
+
+        # create aesara Op for calculating the gradient of the log likelihood
+        class LogLikeGrad(tt.Op):
+
+            itypes = [tt.dvector]
+            otypes = [tt.dvector]
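+            # Op signature: vector of parameter values in, vector of partial derivatives out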
+
+            def __init__(self, parameters, loglike, priors):
+                self.parameters = parameters
+                self.Nparams = len(parameters)
+                self.likelihood = loglike
+                self.priors = priors
+
+                # set the fixed parameters
+                for key in self.priors.keys():
+                    if isinstance(self.priors[key], float):
+                        self.likelihood.parameters[key] = self.priors[key]
+
+            def perform(self, node, inputs, outputs):
+                (theta,) = inputs
+
+                # define version of likelihood function to pass to derivative function
+                def lnlike(values):
+                    for i, key in enumerate(self.parameters):
+                        self.likelihood.parameters[key] = values[i]
+                    return self.likelihood.log_likelihood()
+
+                # calculate gradients
+                grads = derivatives(
+                    theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2
+                )
+
+                outputs[0][0] = grads
+
+        with self.pymc_model:
+            # check if it is a predefined likelihood function
+            if isinstance(self.likelihood, GaussianLikelihood):
+                # check required attributes exist
+                if (
+                    not hasattr(self.likelihood, "sigma")
+                    or not hasattr(self.likelihood, "x")
+                    or not hasattr(self.likelihood, "y")
+                ):
+                    raise ValueError(
+                        "Gaussian Likelihood does not have all the correct attributes!"
+                    )
+
+                if "sigma" in self.pymc_priors:
+                    # take sigma from the prior unless the likelihood already fixes it
+                    if self.likelihood.sigma is None:
+                        self.likelihood.sigma = self.pymc_priors.pop("sigma")
+                    else:
+                        del self.pymc_priors["sigma"]
+
+                for key in self.pymc_priors:
+                    if key not in self.likelihood.function_keys:
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
+
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
+
+                # set the distribution
+                pymc.Normal(
+                    "likelihood",
+                    mu=model,
+                    sigma=self.likelihood.sigma,
+                    observed=self.likelihood.y,
+                )
+            elif isinstance(self.likelihood, PoissonLikelihood):
+                # check required attributes exist
+                if not hasattr(self.likelihood, "x") or not hasattr(
+                    self.likelihood, "y"
+                ):
+                    raise ValueError(
+                        "Poisson Likelihood does not have all the correct attributes!"
+                    )
+
+                for key in self.pymc_priors:
+                    if key not in self.likelihood.function_keys:
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
+
+                # get rate function
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
+
+                # set the distribution
+                pymc.Poisson("likelihood", mu=model, observed=self.likelihood.y)
+            elif isinstance(self.likelihood, ExponentialLikelihood):
+                # check required attributes exist
+                if not hasattr(self.likelihood, "x") or not hasattr(
+                    self.likelihood, "y"
+                ):
+                    raise ValueError(
+                        "Exponential Likelihood does not have all the correct attributes!"
+                    )
+
+                for key in self.pymc_priors:
+                    if key not in self.likelihood.function_keys:
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
+
+                # get mean function
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
+
+                # set the distribution
+                pymc.Exponential(
+                    "likelihood", lam=1.0 / model, observed=self.likelihood.y
+                )
+            elif isinstance(self.likelihood, StudentTLikelihood):
+                # check required attributes exist
+                if (
+                    not hasattr(self.likelihood, "x")
+                    or not hasattr(self.likelihood, "y")
+                    or not hasattr(self.likelihood, "nu")
+                    or not hasattr(self.likelihood, "sigma")
+                ):
+                    raise ValueError(
+                        "StudentT Likelihood does not have all the correct attributes!"
+                    )
+
+                if "nu" in self.pymc_priors:
+                    # take nu from the prior unless the likelihood already fixes it
+                    if self.likelihood.nu is None:
+                        self.likelihood.nu = self.pymc_priors.pop("nu")
+                    else:
+                        del self.pymc_priors["nu"]
+
+                for key in self.pymc_priors:
+                    if key not in self.likelihood.function_keys:
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
+
+                model = self.likelihood.func(self.likelihood.x, **self.pymc_priors)
+
+                # set the distribution
+                pymc.StudentT(
+                    "likelihood",
+                    nu=self.likelihood.nu,
+                    mu=model,
+                    sigma=self.likelihood.sigma,
+                    observed=self.likelihood.y,
+                )
+            elif isinstance(
+                self.likelihood,
+                (GravitationalWaveTransient, BasicGravitationalWaveTransient),
+            ):
+                # set the aesara Op - pass _search_parameter_keys, which only contains non-fixed variables
+                logl = LogLike(
+                    self._search_parameter_keys, self.likelihood, self.pymc_priors
+                )
+
+                parameters = dict()
+                for key in self._search_parameter_keys:
+                    try:
+                        parameters[key] = self.pymc_priors[key]
+                    except KeyError:
+                        raise KeyError(
+                            f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood"
+                        )
+
+                # convert to aesara tensor variable
+                values = tt.as_tensor_variable(list(parameters.values()))
+
+                pymc.DensityDist(
+                    "likelihood", lambda v: logl(v), observed={"v": values}
+                )
+            else:
+                raise ValueError("Unknown likelihood has been provided")
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
deleted file mode 100644
index 4ff4b232adcfb3432a1f8ce2973fb24fc4778d7b..0000000000000000000000000000000000000000
--- a/bilby/core/sampler/pymc3.py
+++ /dev/null
@@ -1,835 +0,0 @@
-from distutils.version import StrictVersion
-
-import numpy as np
-
-from ..utils import derivatives, infer_args_from_method
-from ..prior import DeltaFunction, Sine, Cosine, PowerLaw, MultivariateGaussian
-from .base_sampler import MCMCSampler
-from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikelihood, \
-    StudentTLikelihood
-from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
-
-
-class Pymc3(MCMCSampler):
-    """ bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
-
-    All keyword arguments (i.e., the kwargs) passed to `run_sampler` will be
-    propapated to `pymc3.sample` where appropriate, see documentation for that
-    class for further help. Under Other Parameters, we list commonly used
-    kwargs and the bilby, or where appropriate, PyMC3 defaults.
-
-    Parameters
-    ==========
-    draws: int, (1000)
-        The number of sample draws from the posterior per chain.
-    chains: int, (2)
-        The number of independent MCMC chains to run.
-    cores: int, (1)
-        The number of CPU cores to use.
-    tune: int, (500)
-        The number of tuning (or burn-in) samples per chain.
-    discard_tuned_samples: bool, True
-        Set whether to automatically discard the tuning samples from the final
-        chains.
-    step: str, dict
-        Provide a step method name, or dictionary of step method names keyed to
-        particular variable names (these are case insensitive). If passing a
-        dictionary of methods, the values keyed on particular variables can be
-        lists of methods to form compound steps. If no method is provided for
-        any particular variable then PyMC3 will automatically decide upon a
-        default, with the first option being the NUTS sampler. The currently
-        allowed methods are 'NUTS', 'HamiltonianMC', 'Metropolis',
-        'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and
-        'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step
-        method function itself here as it is outside of the model context
-        manager.
-    step_kwargs: dict
-        Options for steps methods other than NUTS. The dictionary is keyed on
-        lowercase step method names with values being dictionaries of keywords
-        for the given step method.
-
-    """
-
-    default_kwargs = dict(
-        draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0,
-        chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None,
-        discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None,
-        step_kwargs=None,
-    )
-
-    default_nuts_kwargs = dict(
-        target_accept=None, max_treedepth=None, step_scale=None, Emax=None,
-        gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None,
-        scaling=None, is_cov=None, potential=None,
-    )
-
-    default_kwargs.update(default_nuts_kwargs)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False,
-                 skip_import_verification=False, **kwargs):
-        # add default step kwargs
-        _, STEP_METHODS, _ = self._import_external_sampler()
-        self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
-        self.default_kwargs.update(self.default_step_kwargs)
-
-        super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                    use_ratio=use_ratio, plot=plot,
-                                    skip_import_verification=skip_import_verification, **kwargs)
-        self.draws = self._kwargs['draws']
-        self.chains = self._kwargs['chains']
-
-    @staticmethod
-    def _import_external_sampler():
-        import pymc3
-        from pymc3.sampling import STEP_METHODS
-        from pymc3.theanof import floatX
-        return pymc3, STEP_METHODS, floatX
-
-    @staticmethod
-    def _import_theano():
-        import theano  # noqa
-        import theano.tensor as tt
-        from theano.compile.ops import as_op  # noqa
-        return theano, tt, as_op
-
-    def _verify_parameters(self):
-        """
-        Change `_verify_parameters()` to just pass, i.e., don't try and
-        evaluate the likelihood for PyMC3.
-        """
-        pass
-
-    def _verify_use_ratio(self):
-        """
-        Change `_verify_use_ratio() to just pass.
-        """
-        pass
-
-    def setup_prior_mapping(self):
-        """
-        Set the mapping between predefined bilby priors and the equivalent
-        PyMC3 distributions.
-        """
-
-        prior_map = {}
-        self.prior_map = prior_map
-
-        # predefined PyMC3 distributions
-        prior_map['Gaussian'] = {
-            'pymc3': 'Normal',
-            'argmap': {'mu': 'mu', 'sigma': 'sd'}}
-        prior_map['TruncatedGaussian'] = {
-            'pymc3': 'TruncatedNormal',
-            'argmap': {'mu': 'mu',
-                       'sigma': 'sd',
-                       'minimum': 'lower',
-                       'maximum': 'upper'}}
-        prior_map['HalfGaussian'] = {
-            'pymc3': 'HalfNormal',
-            'argmap': {'sigma': 'sd'}}
-        prior_map['Uniform'] = {
-            'pymc3': 'Uniform',
-            'argmap': {'minimum': 'lower',
-                       'maximum': 'upper'}}
-        prior_map['LogNormal'] = {
-            'pymc3': 'Lognormal',
-            'argmap': {'mu': 'mu',
-                       'sigma': 'sd'}}
-        prior_map['Exponential'] = {
-            'pymc3': 'Exponential',
-            'argmap': {'mu': 'lam'},
-            'argtransform': {'mu': lambda mu: 1. / mu}}
-        prior_map['StudentT'] = {
-            'pymc3': 'StudentT',
-            'argmap': {'df': 'nu',
-                       'mu': 'mu',
-                       'scale': 'sd'}}
-        prior_map['Beta'] = {
-            'pymc3': 'Beta',
-            'argmap': {'alpha': 'alpha',
-                       'beta': 'beta'}}
-        prior_map['Logistic'] = {
-            'pymc3': 'Logistic',
-            'argmap': {'mu': 'mu',
-                       'scale': 's'}}
-        prior_map['Cauchy'] = {
-            'pymc3': 'Cauchy',
-            'argmap': {'alpha': 'alpha',
-                       'beta': 'beta'}}
-        prior_map['Gamma'] = {
-            'pymc3': 'Gamma',
-            'argmap': {'k': 'alpha',
-                       'theta': 'beta'},
-            'argtransform': {'theta': lambda theta: 1. / theta}}
-        prior_map['ChiSquared'] = {
-            'pymc3': 'ChiSquared',
-            'argmap': {'nu': 'nu'}}
-        prior_map['Interped'] = {
-            'pymc3': 'Interpolated',
-            'argmap': {'xx': 'x_points',
-                       'yy': 'pdf_points'}}
-        prior_map['Normal'] = prior_map['Gaussian']
-        prior_map['TruncatedNormal'] = prior_map['TruncatedGaussian']
-        prior_map['HalfNormal'] = prior_map['HalfGaussian']
-        prior_map['LogGaussian'] = prior_map['LogNormal']
-        prior_map['Lorentzian'] = prior_map['Cauchy']
-        prior_map['FromFile'] = prior_map['Interped']
-
-        # GW specific priors
-        prior_map['UniformComovingVolume'] = prior_map['Interped']
-
-        # internally defined mappings for bilby priors
-        prior_map['DeltaFunction'] = {'internal': self._deltafunction_prior}
-        prior_map['Sine'] = {'internal': self._sine_prior}
-        prior_map['Cosine'] = {'internal': self._cosine_prior}
-        prior_map['PowerLaw'] = {'internal': self._powerlaw_prior}
-        prior_map['LogUniform'] = {'internal': self._powerlaw_prior}
-        prior_map['MultivariateGaussian'] = {'internal': self._multivariate_normal_prior}
-        prior_map['MultivariateNormal'] = {'internal': self._multivariate_normal_prior}
-
-    def _deltafunction_prior(self, key, **kwargs):
-        """
-        Map the bilby delta function prior to a single value for PyMC3.
-        """
-
-        # check prior is a DeltaFunction
-        if isinstance(self.priors[key], DeltaFunction):
-            return self.priors[key].peak
-        else:
-            raise ValueError("Prior for '{}' is not a DeltaFunction".format(key))
-
-    def _sine_prior(self, key):
-        """
-        Map the bilby Sine prior to a PyMC3 style function
-        """
-
-        # check prior is a Sine
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
-        if isinstance(self.priors[key], Sine):
-
-            class Pymc3Sine(pymc3.Continuous):
-                def __init__(self, lower=0., upper=np.pi):
-                    if lower >= upper:
-                        raise ValueError("Lower bound is above upper bound!")
-
-                    # set the mode
-                    self.lower = lower = tt.as_tensor_variable(floatX(lower))
-                    self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                    self.norm = (tt.cos(lower) - tt.cos(upper))
-                    self.mean = \
-                        (tt.sin(upper) + lower * tt.cos(lower) -
-                         tt.sin(lower) - upper * tt.cos(upper)) / self.norm
-
-                    transform = pymc3.distributions.transforms.interval(lower,
-                                                                        upper)
-
-                    super(Pymc3Sine, self).__init__(transform=transform)
-
-                def logp(self, value):
-                    upper = self.upper
-                    lower = self.lower
-                    return pymc3.distributions.dist_math.bound(
-                        tt.log(tt.sin(value) / self.norm),
-                        lower <= value, value <= upper)
-
-            return Pymc3Sine(key, lower=self.priors[key].minimum,
-                             upper=self.priors[key].maximum)
-        else:
-            raise ValueError("Prior for '{}' is not a Sine".format(key))
-
-    def _cosine_prior(self, key):
-        """
-        Map the bilby Cosine prior to a PyMC3 style function
-        """
-
-        # check prior is a Cosine
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
-        if isinstance(self.priors[key], Cosine):
-
-            class Pymc3Cosine(pymc3.Continuous):
-                def __init__(self, lower=-np.pi / 2., upper=np.pi / 2.):
-                    if lower >= upper:
-                        raise ValueError("Lower bound is above upper bound!")
-
-                    self.lower = lower = tt.as_tensor_variable(floatX(lower))
-                    self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                    self.norm = (tt.sin(upper) - tt.sin(lower))
-                    self.mean = \
-                        (upper * tt.sin(upper) + tt.cos(upper) -
-                         lower * tt.sin(lower) - tt.cos(lower)) / self.norm
-
-                    transform = pymc3.distributions.transforms.interval(lower,
-                                                                        upper)
-
-                    super(Pymc3Cosine, self).__init__(transform=transform)
-
-                def logp(self, value):
-                    upper = self.upper
-                    lower = self.lower
-                    return pymc3.distributions.dist_math.bound(
-                        tt.log(tt.cos(value) / self.norm),
-                        lower <= value, value <= upper)
-
-            return Pymc3Cosine(key, lower=self.priors[key].minimum,
-                               upper=self.priors[key].maximum)
-        else:
-            raise ValueError("Prior for '{}' is not a Cosine".format(key))
-
-    def _powerlaw_prior(self, key):
-        """
-        Map the bilby PowerLaw prior to a PyMC3 style function
-        """
-
-        # check prior is a PowerLaw
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
-        if isinstance(self.priors[key], PowerLaw):
-
-            # check power law is set
-            if not hasattr(self.priors[key], 'alpha'):
-                raise AttributeError("No 'alpha' attribute set for PowerLaw prior")
-
-            if self.priors[key].alpha < -1.:
-                # use Pareto distribution
-                palpha = -(1. + self.priors[key].alpha)
-
-                return pymc3.Bound(
-                    pymc3.Pareto, upper=self.priors[key].minimum)(
-                    key, alpha=palpha, m=self.priors[key].maximum)
-            else:
-                class Pymc3PowerLaw(pymc3.Continuous):
-                    def __init__(self, lower, upper, alpha, testval=1):
-                        falpha = alpha
-                        self.lower = lower = tt.as_tensor_variable(floatX(lower))
-                        self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                        self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
-
-                        if falpha == -1:
-                            self.norm = 1. / (tt.log(self.upper / self.lower))
-                        else:
-                            beta = (1. + self.alpha)
-                            self.norm = 1. / (beta * (tt.pow(self.upper, beta) -
-                                                      tt.pow(self.lower, beta)))
-
-                        transform = pymc3.distributions.transforms.interval(
-                            lower, upper)
-
-                        super(Pymc3PowerLaw, self).__init__(
-                            transform=transform, testval=testval)
-
-                    def logp(self, value):
-                        upper = self.upper
-                        lower = self.lower
-                        alpha = self.alpha
-
-                        return pymc3.distributions.dist_math.bound(
-                            alpha * tt.log(value) + tt.log(self.norm),
-                            lower <= value, value <= upper)
-
-                return Pymc3PowerLaw(key, lower=self.priors[key].minimum,
-                                     upper=self.priors[key].maximum,
-                                     alpha=self.priors[key].alpha)
-        else:
-            raise ValueError("Prior for '{}' is not a Power Law".format(key))
-
-    def _multivariate_normal_prior(self, key):
-        """
-        Map the bilby MultivariateNormal prior to a PyMC3 style function.
-        """
-
-        # check prior is a PowerLaw
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
-        if isinstance(self.priors[key], MultivariateGaussian):
-            # get names of multivariate Gaussian parameters
-            mvpars = self.priors[key].mvg.names
-
-            # set the prior on multiple parameters if not present yet
-            if not np.all([p in self.multivariate_normal_sets for p in mvpars]):
-                mvg = self.priors[key].mvg
-
-                # get bounds
-                lower = [bound[0] for bound in mvg.bounds.values()]
-                upper = [bound[1] for bound in mvg.bounds.values()]
-
-                # test values required for mixture
-                testvals = []
-                for bound in mvg.bounds.values():
-                    if np.isinf(bound[0]) and np.isinf(bound[1]):
-                        testvals.append(0.)
-                    elif np.isinf(bound[0]):
-                        testvals.append(bound[1] - 1.)
-                    elif np.isinf(bound[1]):
-                        testvals.append(bound[0] + 1.)
-                    else:
-                        # half-way between the two bounds
-                        testvals.append(bound[0] + (bound[1] - bound[0]) / 2.)
-
-                # if bounds are at +/-infinity set to 100 sigmas as infinities
-                # cause problems for the Bound class
-                maxmu = np.max(mvg.mus, axis=0)
-                minmu = np.min(mvg.mus, axis=0)
-                maxsigma = np.max(mvg.sigmas, axis=0)
-                for i in range(len(mvpars)):
-                    if np.isinf(lower[i]):
-                        lower[i] = minmu[i] - 100. * maxsigma[i]
-                    if np.isinf(upper[i]):
-                        upper[i] = maxmu[i] + 100. * maxsigma[i]
-
-                # create a bounded MultivariateNormal distribution
-                BoundedMvN = pymc3.Bound(pymc3.MvNormal, lower=lower, upper=upper)
-
-                comp_dists = []  # list of any component modes
-                for i in range(mvg.nmodes):
-                    comp_dists.append(BoundedMvN('comp{}'.format(i), mu=mvg.mus[i],
-                                                 cov=mvg.covs[i],
-                                                 shape=len(mvpars)).distribution)
-
-                # create a Mixture model
-                setname = 'mixture{}'.format(self.multivariate_normal_num_sets)
-                mix = pymc3.Mixture(setname, w=mvg.weights, comp_dists=comp_dists,
-                                    shape=len(mvpars), testval=testvals)
-
-                for i, p in enumerate(mvpars):
-                    self.multivariate_normal_sets[p] = {}
-                    self.multivariate_normal_sets[p]['prior'] = mix[i]
-                    self.multivariate_normal_sets[p]['set'] = setname
-                    self.multivariate_normal_sets[p]['index'] = i
-
-                self.multivariate_normal_num_sets += 1
-
-            # return required parameter
-            return self.multivariate_normal_sets[key]['prior']
-
-        else:
-            raise ValueError("Prior for '{}' is not a MultivariateGaussian".format(key))
-
-    def run_sampler(self):
-        # set the step method
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
-        if 'step' in self._kwargs:
-            self.step_method = self._kwargs.pop('step')
-
-            # 'step' could be a dictionary of methods for different parameters,
-            # so check for this
-            if self.step_method is None:
-                pass
-            elif isinstance(self.step_method, dict):
-                for key in self.step_method:
-                    if key not in self._search_parameter_keys:
-                        raise ValueError("Setting a step method for an unknown parameter '{}'".format(key))
-                    else:
-                        # check if using a compound step (a list of step
-                        # methods for a particular parameter)
-                        if isinstance(self.step_method[key], list):
-                            sms = self.step_method[key]
-                        else:
-                            sms = [self.step_method[key]]
-                        for sm in sms:
-                            if sm.lower() not in step_methods:
-                                raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
-            else:
-                # check if using a compound step (a list of step
-                # methods for a particular parameter)
-                if isinstance(self.step_method, list):
-                    sms = self.step_method
-                else:
-                    sms = [self.step_method]
-
-                for i in range(len(sms)):
-                    if sms[i].lower() not in step_methods:
-                        raise ValueError("Using invalid step method '{}'".format(sms[i]))
-        else:
-            self.step_method = None
-
-        # initialise the PyMC3 model
-        self.pymc3_model = pymc3.Model()
-
-        # set the prior
-        self.set_prior()
-
-        # if a custom log_likelihood function requires a `sampler` argument
-        # then use that log_likelihood function, with the assumption that it
-        # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
-        # the likelihood within that context manager
-        likeargs = infer_args_from_method(self.likelihood.log_likelihood)
-        if 'sampler' in likeargs:
-            self.likelihood.log_likelihood(sampler=self)
-        else:
-            # set the likelihood function from predefined functions
-            self.set_likelihood()
-
-        # get the step method keyword arguments
-        step_kwargs = self.kwargs.pop("step_kwargs")
-        if step_kwargs is not None:
-            # remove all individual default step kwargs if passed together using
-            # step_kwargs keywords
-            for key in self.default_step_kwargs:
-                self.kwargs.pop(key)
-        else:
-            # remove any None default step keywords and place others in step_kwargs
-            step_kwargs = {}
-            for key in self.default_step_kwargs:
-                if self.kwargs[key] is None:
-                    self.kwargs.pop(key)
-                else:
-                    step_kwargs[key] = self.kwargs.pop(key)
-
-        nuts_kwargs = self.kwargs.pop("nuts_kwargs")
-        if nuts_kwargs is not None:
-            # remove all individual default nuts kwargs if passed together using
-            # nuts_kwargs keywords
-            for key in self.default_nuts_kwargs:
-                self.kwargs.pop(key)
-        else:
-            # remove any None default nuts keywords and place others in nut_kwargs
-            nuts_kwargs = {}
-            for key in self.default_nuts_kwargs:
-                if self.kwargs[key] is None:
-                    self.kwargs.pop(key)
-                else:
-                    nuts_kwargs[key] = self.kwargs.pop(key)
-        methodslist = []
-
-        # set the step method
-        if isinstance(self.step_method, dict):
-            # create list of step methods (any not given will default to NUTS)
-            self.kwargs['step'] = []
-            with self.pymc3_model:
-                for key in self.step_method:
-                    # check for a compound step list
-                    if isinstance(self.step_method[key], list):
-                        for sms in self.step_method[key]:
-                            curmethod = sms.lower()
-                            methodslist.append(curmethod)
-                            nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
-                                                                   step_methods)
-                    else:
-                        curmethod = self.step_method[key].lower()
-                        methodslist.append(curmethod)
-                        nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
-                                                               step_methods)
-        else:
-            with self.pymc3_model:
-                # check for a compound step list
-                if isinstance(self.step_method, list):
-                    compound = []
-                    for sms in self.step_method:
-                        curmethod = sms.lower()
-                        methodslist.append(curmethod)
-                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
-                        compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
-                        self.kwargs['step'] = compound
-                else:
-                    self.kwargs['step'] = None
-                    if self.step_method is not None:
-                        curmethod = self.step_method.lower()
-                        methodslist.append(curmethod)
-                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
-                        self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
-                    else:
-                        # re-add step_kwargs if no step methods are set
-                        if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
-                            self.kwargs['step_kwargs'] = step_kwargs
-
-        # check whether only NUTS step method has been assigned
-        if np.all([sm.lower() == 'nuts' for sm in methodslist]):
-            # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
-            self.kwargs['step'] = None
-
-            if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
-                self.kwargs['nuts_kwargs'] = nuts_kwargs
-            elif len(nuts_kwargs) > 0:
-                # add NUTS kwargs to standard kwargs
-                self.kwargs.update(nuts_kwargs)
-
-        with self.pymc3_model:
-            # perform the sampling
-            trace = pymc3.sample(**self.kwargs, return_inferencedata=True)
-
-        posterior = trace.posterior.to_dataframe().reset_index()
-        self.result.samples = posterior[self.search_parameter_keys]
-        self.result.log_likelihood_evaluations = np.sum(
-            trace.log_likelihood.likelihood.values, axis=-1
-        ).flatten()
-        self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(self.result.samples)
-        self.result.log_evidence = np.nan
-        self.result.log_evidence_err = np.nan
-        self.calc_likelihood_count()
-        return self.result
-
-    def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs):
-        if curmethod == 'nuts':
-            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
-        else:
-            args = step_kwargs.get(curmethod, {})
-        return args, nuts_kwargs
-
-    def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods):
-        if curmethod == 'nuts':
-            args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
-        else:
-            if step_kwargs is not None:
-                args = step_kwargs.get(curmethod, {})
-            else:
-                args = {}
-        self.kwargs['step'].append(
-            pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
-        return nuts_kwargs
-
-    @staticmethod
-    def _get_nuts_args(nuts_kwargs, step_kwargs):
-        if nuts_kwargs is not None:
-            args = nuts_kwargs
-        elif step_kwargs is not None:
-            args = step_kwargs.pop('nuts', {})
-            # add values into nuts_kwargs
-            nuts_kwargs = args
-        else:
-            args = {}
-        return args, nuts_kwargs
-
-    def _pymc3_version(self):
-        pymc3, _, _ = self._import_external_sampler()
-        return pymc3.__version__
-
-    def set_prior(self):
-        """
-        Set the PyMC3 prior distributions.
-        """
-
-        self.setup_prior_mapping()
-
-        self.pymc3_priors = dict()
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-
-        # initialise a dictionary of multivariate Gaussian parameters
-        self.multivariate_normal_sets = {}
-        self.multivariate_normal_num_sets = 0
-
-        # set the parameter prior distributions (in the model context manager)
-        with self.pymc3_model:
-            for key in self.priors:
-                # if the prior contains ln_prob method that takes a 'sampler' argument
-                # then try using that
-                lnprobargs = infer_args_from_method(self.priors[key].ln_prob)
-                if 'sampler' in lnprobargs:
-                    try:
-                        self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self)
-                    except RuntimeError:
-                        raise RuntimeError(("Problem setting PyMC3 prior for ",
-                                            "'{}'".format(key)))
-                else:
-                    # use Prior distribution name
-                    distname = self.priors[key].__class__.__name__
-
-                    if distname in self.prior_map:
-                        # check if we have a predefined PyMC3 distribution
-                        if 'pymc3' in self.prior_map[distname] and 'argmap' in self.prior_map[distname]:
-                            # check the required arguments for the PyMC3 distribution
-                            pymc3distname = self.prior_map[distname]['pymc3']
-
-                            if pymc3distname not in pymc3.__dict__:
-                                raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname))
-
-                            reqargs = infer_args_from_method(pymc3.__dict__[pymc3distname].__init__)
-
-                            # set keyword arguments
-                            priorkwargs = {}
-                            for (targ, parg) in self.prior_map[distname]['argmap'].items():
-                                if hasattr(self.priors[key], targ):
-                                    if parg in reqargs:
-                                        if 'argtransform' in self.prior_map[distname]:
-                                            if targ in self.prior_map[distname]['argtransform']:
-                                                tfunc = self.prior_map[distname]['argtransform'][targ]
-                                            else:
-                                                def tfunc(x):
-                                                    return x
-                                        else:
-                                            def tfunc(x):
-                                                return x
-
-                                        priorkwargs[parg] = tfunc(getattr(self.priors[key], targ))
-                                    else:
-                                        raise ValueError("Unknown argument {}".format(parg))
-                                else:
-                                    if parg in reqargs:
-                                        priorkwargs[parg] = None
-                            self.pymc3_priors[key] = pymc3.__dict__[pymc3distname](key, **priorkwargs)
-                        elif 'internal' in self.prior_map[distname]:
-                            self.pymc3_priors[key] = self.prior_map[distname]['internal'](key)
-                        else:
-                            raise ValueError("Prior '{}' is not a known distribution.".format(distname))
-                    else:
-                        raise ValueError("Prior '{}' is not a known distribution.".format(distname))
-
-    def set_likelihood(self):
-        """
-        Convert any bilby likelihoods to PyMC3 distributions.
-        """
-
-        # create theano Op for the log likelihood if not using a predefined model
-        pymc3, STEP_METHODS, floatX = self._import_external_sampler()
-        theano, tt, as_op = self._import_theano()
-
-        class LogLike(tt.Op):
-
-            itypes = [tt.dvector]
-            otypes = [tt.dscalar]
-
-            def __init__(self, parameters, loglike, priors):
-                self.parameters = parameters
-                self.likelihood = loglike
-                self.priors = priors
-
-                # set the fixed parameters
-                for key in self.priors.keys():
-                    if isinstance(self.priors[key], float):
-                        self.likelihood.parameters[key] = self.priors[key]
-
-                self.logpgrad = LogLikeGrad(self.parameters, self.likelihood, self.priors)
-
-            def perform(self, node, inputs, outputs):
-                theta, = inputs
-                for i, key in enumerate(self.parameters):
-                    self.likelihood.parameters[key] = theta[i]
-
-                outputs[0][0] = np.array(self.likelihood.log_likelihood())
-
-            def grad(self, inputs, g):
-                theta, = inputs
-                return [g[0] * self.logpgrad(theta)]
-
-        # create theano Op for calculating the gradient of the log likelihood
-        class LogLikeGrad(tt.Op):
-
-            itypes = [tt.dvector]
-            otypes = [tt.dvector]
-
-            def __init__(self, parameters, loglike, priors):
-                self.parameters = parameters
-                self.Nparams = len(parameters)
-                self.likelihood = loglike
-                self.priors = priors
-
-                # set the fixed parameters
-                for key in self.priors.keys():
-                    if isinstance(self.priors[key], float):
-                        self.likelihood.parameters[key] = self.priors[key]
-
-            def perform(self, node, inputs, outputs):
-                theta, = inputs
-
-                # define version of likelihood function to pass to derivative function
-                def lnlike(values):
-                    for i, key in enumerate(self.parameters):
-                        self.likelihood.parameters[key] = values[i]
-                    return self.likelihood.log_likelihood()
-
-                # calculate gradients
-                grads = derivatives(theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2)
-
-                outputs[0][0] = grads
-
-        with self.pymc3_model:
-            #  check if it is a predefined likelhood function
-            if isinstance(self.likelihood, GaussianLikelihood):
-                # check required attributes exist
-                if (not hasattr(self.likelihood, 'sigma') or
-                        not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Gaussian Likelihood does not have all the correct attributes!")
-
-                if 'sigma' in self.pymc3_priors:
-                    # if sigma is suppled use that value
-                    if self.likelihood.sigma is None:
-                        self.likelihood.sigma = self.pymc3_priors.pop('sigma')
-                    else:
-                        del self.pymc3_priors['sigma']
-
-                for key in self.pymc3_priors:
-                    if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
-
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
-
-                # set the distribution
-                pymc3.Normal('likelihood', mu=model, sd=self.likelihood.sigma,
-                             observed=self.likelihood.y)
-            elif isinstance(self.likelihood, PoissonLikelihood):
-                # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Poisson Likelihood does not have all the correct attributes!")
-
-                for key in self.pymc3_priors:
-                    if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
-
-                # get rate function
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
-
-                # set the distribution
-                pymc3.Poisson('likelihood', mu=model, observed=self.likelihood.y)
-            elif isinstance(self.likelihood, ExponentialLikelihood):
-                # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Exponential Likelihood does not have all the correct attributes!")
-
-                for key in self.pymc3_priors:
-                    if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
-
-                # get mean function
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
-
-                # set the distribution
-                pymc3.Exponential('likelihood', lam=1. / model, observed=self.likelihood.y)
-            elif isinstance(self.likelihood, StudentTLikelihood):
-                # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y') or
-                        not hasattr(self.likelihood, 'nu') or
-                        not hasattr(self.likelihood, 'sigma')):
-                    raise ValueError("StudentT Likelihood does not have all the correct attributes!")
-
-                if 'nu' in self.pymc3_priors:
-                    # if nu is suppled use that value
-                    if self.likelihood.nu is None:
-                        self.likelihood.nu = self.pymc3_priors.pop('nu')
-                    else:
-                        del self.pymc3_priors['nu']
-
-                for key in self.pymc3_priors:
-                    if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
-
-                model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
-
-                # set the distribution
-                pymc3.StudentT('likelihood', nu=self.likelihood.nu, mu=model, sd=self.likelihood.sigma,
-                               observed=self.likelihood.y)
-            elif isinstance(self.likelihood, (GravitationalWaveTransient, BasicGravitationalWaveTransient)):
-                # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables
-                logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc3_priors)
-
-                parameters = dict()
-                for key in self._search_parameter_keys:
-                    try:
-                        parameters[key] = self.pymc3_priors[key]
-                    except KeyError:
-                        raise KeyError(
-                            "Unknown key '{}' when setting GravitationalWaveTransient likelihood".format(key))
-
-                # convert to theano tensor variable
-                values = tt.as_tensor_variable(list(parameters.values()))
-
-                pymc3.DensityDist('likelihood', lambda v: logl(v), observed={'v': values})
-            else:
-                raise ValueError("Unknown likelihood has been provided")
diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index d9869362256ad60c8029acdd0e19020a31dea8f8..6f0349fe33964ad381675dbc713edfe2dcc4ab1f 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -1,20 +1,15 @@
+import datetime
 import importlib
 import os
-import shutil
-import distutils.dir_util
-import signal
 import time
-import datetime
-import sys
 
 import numpy as np
 
-from ..utils import check_directory_exists_and_if_not_mkdir
 from ..utils import logger
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
-class Pymultinest(NestedSampler):
+class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
     """
     bilby wrapper of pymultinest
     (https://github.com/JohannesBuchner/PyMultiNest)
@@ -65,6 +60,9 @@ class Pymultinest(NestedSampler):
         init_MPI=False,
         dump_callback=None,
     )
+    short_name = "pm"
+    hard_exit = True
+    sampling_seed_key = "seed"
 
     def __init__(
         self,
@@ -94,6 +92,7 @@ class Pymultinest(NestedSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
+            temporary_directory=temporary_directory,
             **kwargs
         )
         self._apply_multinest_boundaries()
@@ -105,11 +104,8 @@ class Pymultinest(NestedSampler):
             )
         self.use_temporary_directory = temporary_directory and not using_mpi
 
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "n_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
@@ -141,74 +137,7 @@ class Pymultinest(NestedSampler):
                 else:
                     self.kwargs["wrapped_params"].append(0)
 
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = "{}/pm_{}/".format(self.outdir, self.label)
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """Write current state and exit on exit_code"""
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        sys.exit(self.exit_code)
-
-    def _copy_temporary_directory_contents_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path.
-        Do not delete the temporary directory.
-        """
-        logger.info(
-            "Overwriting {} with {}".format(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-        )
-        if self.outputfiles_basename.endswith("/"):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(
-            self.temporary_outputfiles_basename, outputfiles_basename_stripped
-        )
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path
-
-        Anything in the temporary directory at this point is removed
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
-
+    @signal_wrapper
     def run_sampler(self):
         import pymultinest
 
@@ -247,27 +176,6 @@ class Pymultinest(NestedSampler):
         self.result.nested_samples = self._nested_samples
         return self.result
 
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, "r") as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-        self.start_time = current_time
-        with open(self.time_file_path, "w") as time_file:
-            time_file.write(str(self.total_sampling_time))
-
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-
     @property
     def _nested_samples(self):
         """
diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py
index 2348319e4b6048eb2669e798e94c6a06af9b6e0e..4cc14a9fa7ff9ba1ce4f19979571f8dde2c71e7b 100644
--- a/bilby/core/sampler/ultranest.py
+++ b/bilby/core/sampler/ultranest.py
@@ -1,20 +1,15 @@
-
 import datetime
-import distutils.dir_util
 import inspect
-import os
-import shutil
-import signal
 import time
 
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
-from .base_sampler import NestedSampler
+from ..utils import logger
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
-class Ultranest(NestedSampler):
+class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
     """
     bilby wrapper of ultranest
     (https://johannesbuchner.github.io/UltraNest/index.html)
@@ -73,6 +68,8 @@ class Ultranest(NestedSampler):
         step_sampler=None,
     )
 
+    short_name = "ultra"
+
     def __init__(
         self,
         likelihood,
@@ -96,31 +93,32 @@ class Ultranest(NestedSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
+            temporary_directory=temporary_directory,
             **kwargs,
         )
         self._apply_ultranest_boundaries()
-        self.use_temporary_directory = temporary_directory
 
         if self.use_temporary_directory:
             # set callback interval, so copying of results does not thrash the
             # disk (ultranest will call viz_callback quite a lot)
             self.callback_interval = callback_interval
 
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
     def _translate_kwargs(self, kwargs):
+        kwargs = super()._translate_kwargs(kwargs)
         if "num_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["num_live_points"] = kwargs.pop(equiv)
-
         if "verbose" in kwargs and "show_status" not in kwargs:
             kwargs["show_status"] = kwargs.pop("verbose")
+        # a boolean resume is mapped to ultranest's "overwrite" mode; explicit
+        # string values (e.g. "resume-similar") pass through unchanged
+        if isinstance(kwargs.get("resume", False), bool):
+            kwargs["resume"] = "overwrite"
 
     def _verify_kwargs_against_default_kwargs(self):
-        """ Check the kwargs """
+        """Check the kwargs"""
 
         self.outputfiles_basename = self.kwargs.pop("log_dir", None)
         if self.kwargs["viz_callback"] is None:
@@ -148,76 +146,13 @@ class Ultranest(NestedSampler):
                     else:
                         self.kwargs["wrapped_params"].append(0)
 
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = os.path.join(
-                self.outdir, "ultra_{}/".format(self.label)
-            )
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        os._exit(self.exit_code)
-
     def _copy_temporary_directory_contents_to_proper_path(self):
         """
         Copy the temporary back to the proper path.
         Do not delete the temporary directory.
         """
         if inspect.stack()[1].function != "_viz_callback":
-            logger.info(
-                "Overwriting {} with {}".format(
-                    self.outputfiles_basename, self.temporary_outputfiles_basename
-                )
-            )
-        if self.outputfiles_basename.endswith("/"):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(
-            self.temporary_outputfiles_basename, outputfiles_basename_stripped
-        )
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
+            super(Ultranest, self)._copy_temporary_directory_contents_to_proper_path()
 
     @property
     def sampler_function_kwargs(self):
@@ -271,6 +206,7 @@ class Ultranest(NestedSampler):
 
         return init_kwargs
 
+    @signal_wrapper
     def run_sampler(self):
         import ultranest
         import ultranest.stepsampler
@@ -285,7 +221,7 @@ class Ultranest(NestedSampler):
         stepsampler = self.kwargs.pop("step_sampler", None)
 
         self._setup_run_directory()
-        self.kwargs["log_dir"] = self.kwargs.pop("outputfiles_basename")
+        self.kwargs["log_dir"] = self.kwargs["outputfiles_basename"]
         self._check_and_load_sampling_time_file()
 
         # use reactive nested sampler when no live points are given
@@ -317,7 +253,6 @@ class Ultranest(NestedSampler):
         results = sampler.run(**self.sampler_function_kwargs)
         self._calculate_and_save_sampling_time()
 
-        # Clean up
         self._clean_up_run_directory()
 
         self._generate_result(results)
@@ -325,27 +260,6 @@ class Ultranest(NestedSampler):
 
         return self.result
 
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["log_dir"] = self.outputfiles_basename
-
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, "r") as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-        with open(self.time_file_path, "w") as time_file:
-            time_file.write(str(self.total_sampling_time))
-        self.start_time = current_time
-
     def _generate_result(self, out):
         # extract results
         data = np.array(out["weighted_samples"]["points"])
@@ -357,16 +271,22 @@ class Ultranest(NestedSampler):
         nested_samples = DataFrame(data, columns=self.search_parameter_keys)
         nested_samples["weights"] = weights
         nested_samples["log_likelihood"] = out["weighted_samples"]["logl"]
-        self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["logl"])[
-            mask
-        ]
+        self.result.log_likelihood_evaluations = np.array(
+            out["weighted_samples"]["logl"]
+        )[mask]
         self.result.sampler_output = out
         self.result.samples = data[mask, :]
         self.result.nested_samples = nested_samples
         self.result.log_evidence = out["logz"]
         self.result.log_evidence_err = out["logzerr"]
         if self.kwargs["num_live_points"] is not None:
-            self.result.information_gain = np.power(out["logzerr"], 2) * self.kwargs["num_live_points"]
+            self.result.information_gain = (
+                np.power(out["logzerr"], 2) * self.kwargs["num_live_points"]
+            )
 
         self.result.outputfiles_basename = self.outputfiles_basename
         self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
+
+    def log_likelihood(self, theta):
+        log_l = super(Ultranest, self).log_likelihood(theta=theta)
+        return np.nan_to_num(log_l)
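# --- Editor's note (not part of the diff) -----------------------------------
# ultranest requires finite log-likelihood values, hence the override above.
# np.nan_to_num maps NaN to 0.0 and +/-inf to the largest finite floats, so
# every value handed back to the sampler is finite:
import numpy as np

assert np.nan_to_num(np.nan) == 0.0
assert np.nan_to_num(-np.inf) < -1e300  # most negative finite float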
diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py
index 78c3529ea00c1c9ee518a8698e2f70bc29fe194d..c7ae40da222201e5b29c53635c16c3edc94744f0 100644
--- a/bilby/core/sampler/zeus.py
+++ b/bilby/core/sampler/zeus.py
@@ -1,18 +1,17 @@
 import os
-import signal
 import shutil
-import sys
-from collections import namedtuple
 from shutil import copyfile
 
 import numpy as np
-from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import MCMCSampler, SamplerError
+from .base_sampler import SamplerError, signal_wrapper
+from .emcee import Emcee
+from .ptemcee import LikePriorEvaluator
 
+_evaluator = LikePriorEvaluator()
 
-class Zeus(MCMCSampler):
+
+class Zeus(Emcee):
     """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/)
 
     All positional and keyword arguments (i.e., the args and kwargs) passed to
@@ -65,12 +64,8 @@ class Zeus(MCMCSampler):
         burn_in_fraction=0.25,
         resume=True,
         burn_in_act=3,
-        **kwargs
+        **kwargs,
     ):
-        import zeus
-
-        self.zeus = zeus
-
         super(Zeus, self).__init__(
             likelihood=likelihood,
             priors=priors,
@@ -79,25 +74,16 @@ class Zeus(MCMCSampler):
             use_ratio=use_ratio,
             plot=plot,
             skip_import_verification=skip_import_verification,
-            **kwargs
+            pos0=pos0,
+            nburn=nburn,
+            burn_in_fraction=burn_in_fraction,
+            resume=resume,
+            burn_in_act=burn_in_act,
+            **kwargs,
         )
-        self.resume = resume
-        self.pos0 = pos0
-        self.nburn = nburn
-        self.burn_in_fraction = burn_in_fraction
-        self.burn_in_act = burn_in_act
-
-        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
-        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
 
     def _translate_kwargs(self, kwargs):
-        if "nwalkers" not in kwargs:
-            for equiv in self.nwalkers_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs["nwalkers"] = kwargs.pop(equiv)
-        if "iterations" not in kwargs:
-            if "nsteps" in kwargs:
-                kwargs["iterations"] = kwargs.pop("nsteps")
+        super(Zeus, self)._translate_kwargs(kwargs=kwargs)
 
         # check if using emcee-style arguments
         if "start" not in kwargs:
@@ -107,17 +93,6 @@ class Zeus(MCMCSampler):
             if "lnprob0" in kwargs:
                 kwargs["log_prob0"] = kwargs.pop("lnprob0")
 
-        if "threads" in kwargs:
-            if kwargs["threads"] != 1:
-                logger.warning(
-                    "The 'threads' argument cannot be used for "
-                    "parallelisation. This run will proceed "
-                    "without parallelisation, but consider the use "
-                    "of an appropriate Pool object passed to the "
-                    "'pool' keyword."
-                )
-                kwargs["threads"] = 1
-
     @property
     def sampler_function_kwargs(self):
         keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"]
@@ -134,168 +109,21 @@ class Zeus(MCMCSampler):
             if key not in self.sampler_function_kwargs
         }
 
-        init_kwargs["logprob_fn"] = self.lnpostfn
+        init_kwargs["logprob_fn"] = _evaluator.call_emcee
         init_kwargs["ndim"] = self.ndim
 
         return init_kwargs
 
-    def lnpostfn(self, theta):
-        log_prior = self.log_prior(theta)
-        if np.isinf(log_prior):
-            return -np.inf, [np.nan, np.nan]
-        else:
-            log_likelihood = self.log_likelihood(theta)
-            return log_likelihood + log_prior, [log_likelihood, log_prior]
-
-    @property
-    def nburn(self):
-        if type(self.__nburn) in [float, int]:
-            return int(self.__nburn)
-        elif self.result.max_autocorrelation_time is None:
-            return int(self.burn_in_fraction * self.nsteps)
-        else:
-            return int(self.burn_in_act * self.result.max_autocorrelation_time)
-
-    @nburn.setter
-    def nburn(self, nburn):
-        if isinstance(nburn, (float, int)):
-            if nburn > self.kwargs["iterations"] - 1:
-                raise ValueError(
-                    "Number of burn-in samples must be smaller "
-                    "than the total number of iterations"
-                )
-
-        self.__nburn = nburn
-
-    @property
-    def nwalkers(self):
-        return self.kwargs["nwalkers"]
-
-    @property
-    def nsteps(self):
-        return self.kwargs["iterations"]
-
-    @nsteps.setter
-    def nsteps(self, nsteps):
-        self.kwargs["iterations"] = nsteps
-
-    @property
-    def stored_chain(self):
-        """Read the stored zero-temperature chain data in from disk"""
-        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
-
-    @property
-    def stored_samples(self):
-        """Returns the samples stored on disk"""
-        return self.stored_chain[self.search_parameter_keys]
-
-    @property
-    def stored_loglike(self):
-        """Returns the log-likelihood stored on disk"""
-        return self.stored_chain["log_l"]
-
-    @property
-    def stored_logprior(self):
-        """Returns the log-prior stored on disk"""
-        return self.stored_chain["log_p"]
-
-    def _init_chain_file(self):
-        with open(self.checkpoint_info.chain_file, "w+") as ff:
-            ff.write(
-                "walker\t{}\tlog_l\tlog_p\n".format(
-                    "\t".join(self.search_parameter_keys)
-                )
-            )
-
-    @property
-    def checkpoint_info(self):
-        """Defines various things related to checkpointing and storing data
-
-        Returns
-        =======
-        checkpoint_info: named_tuple
-            An object with attributes `sampler_file`, `chain_file`, and
-            `chain_template`. The first two give paths to where the sampler and
-            chain data is stored, the last a formatted-str-template with which
-            to write the chain data to disk
-
-        """
-        out_dir = os.path.join(
-            self.outdir, "{}_{}".format(self.__class__.__name__.lower(), self.label)
-        )
-        check_directory_exists_and_if_not_mkdir(out_dir)
-
-        chain_file = os.path.join(out_dir, "chain.dat")
-        sampler_file = os.path.join(out_dir, "sampler.pickle")
-        chain_template = (
-            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
-        )
-
-        CheckpointInfo = namedtuple(
-            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
-        )
-
-        checkpoint_info = CheckpointInfo(
-            sampler_file=sampler_file,
-            chain_file=chain_file,
-            chain_template=chain_template,
-        )
-
-        return checkpoint_info
-
-    @property
-    def sampler_chain(self):
-        nsteps = self._previous_iterations
-        return self.sampler.chain[:, :nsteps, :]
-
-    def checkpoint(self):
-        """Writes a pickle file of the sampler to disk using dill"""
-        import dill
-
-        logger.info(
-            "Checkpointing sampler to file {}".format(self.checkpoint_info.sampler_file)
-        )
-        with open(self.checkpoint_info.sampler_file, "wb") as f:
-            # Overwrites the stored sampler chain with one that is truncated
-            # to only the completed steps
-            self.sampler._chain = self.sampler_chain
-            dill.dump(self._sampler, f)
-
-    def checkpoint_and_exit(self, signum, frame):
-        logger.info("Received signal {}".format(signum))
-        self.checkpoint()
-        sys.exit()
+    def write_current_state(self):
+        self._sampler.distribute = map
+        super(Zeus, self).write_current_state()
+        self._sampler.distribute = getattr(self._sampler.pool, "map", map)
 
     def _initialise_sampler(self):
-        self._sampler = self.zeus.EnsembleSampler(**self.sampler_init_kwargs)
-        self._init_chain_file()
+        from zeus import EnsembleSampler
 
-    @property
-    def sampler(self):
-        """Returns the Zeus sampler object
-
-        If, already initialized, returns the stored _sampler value. Otherwise,
-        first checks if there is a pickle file from which to load. If there is
-        not, then initialize the sampler and set the initial random draw
-
-        """
-        if hasattr(self, "_sampler"):
-            pass
-        elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
-            import dill
-
-            logger.info(
-                "Resuming run from checkpoint file {}".format(
-                    self.checkpoint_info.sampler_file
-                )
-            )
-            with open(self.checkpoint_info.sampler_file, "rb") as f:
-                self._sampler = dill.load(f)
-            self._set_pos0_for_resume()
-        else:
-            self._initialise_sampler()
-            self._set_pos0()
-        return self._sampler
+        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
+        self._init_chain_file()
 
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
@@ -310,48 +138,12 @@ class Zeus(MCMCSampler):
                 ff.write(self.checkpoint_info.chain_template.format(ii, *point))
         shutil.move(temp_chain_file, chain_file)
 
-    @property
-    def _previous_iterations(self):
-        """Returns the number of iterations that the sampler has saved
-
-        This is used when loading in a sampler from a pickle file to figure out
-        how much of the run has already been completed
-        """
-        try:
-            return len(self.sampler.get_blobs())
-        except AttributeError:
-            return 0
-
-    def _draw_pos0_from_prior(self):
-        return np.array(
-            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
-        )
-
-    @property
-    def _pos0_shape(self):
-        return (self.nwalkers, self.ndim)
-
-    def _set_pos0(self):
-        if self.pos0 is not None:
-            logger.debug("Using given initial positions for walkers")
-            if isinstance(self.pos0, DataFrame):
-                self.pos0 = self.pos0[self.search_parameter_keys].values
-            elif type(self.pos0) in (list, np.ndarray):
-                self.pos0 = np.squeeze(self.pos0)
-
-            if self.pos0.shape != self._pos0_shape:
-                raise ValueError("Input pos0 should be of shape ndim, nwalkers")
-            logger.debug("Checking input pos0")
-            for draw in self.pos0:
-                self.check_draw(draw)
-        else:
-            logger.debug("Generating initial walker positions from prior")
-            self.pos0 = self._draw_pos0_from_prior()
-
     def _set_pos0_for_resume(self):
         self.pos0 = self.sampler.get_last_sample()
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         sampler_function_kwargs = self.sampler_function_kwargs
         iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
@@ -363,7 +155,8 @@ class Zeus(MCMCSampler):
             iterations=iterations, **sampler_function_kwargs
         ):
             self.write_chains_to_file(sample)
-        self.checkpoint()
+        self._close_pool()
+        self.write_current_state()
 
         self.result.sampler_output = np.nan
         self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
@@ -381,10 +174,12 @@ class Zeus(MCMCSampler):
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps` ({} < {}). Try increasing the "
-                "number of steps.".format(self.result.nburn, self.nsteps)
+                f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})."
+                " Try increasing the number of steps."
             )
-        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2))
+        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape(
+            (-1, 2)
+        )
         log_likelihoods, log_priors = blobs.T
         self.result.log_likelihood_evaluations = log_likelihoods
         self.result.log_prior_evaluations = log_priors
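# --- Editor's note (not part of the diff) -----------------------------------
# Handing zeus the module-level _evaluator.call_emcee rather than a bound
# method keeps the callable sent to worker processes small and picklable; the
# heavy likelihood is installed once per worker by the pool initializer. A
# toy, hypothetical version of that pattern:
import multiprocessing

_worker_state = {}


def _init_worker(log_post_fn):
    # runs once in each worker; stash the expensive object in process globals
    _worker_state["log_post"] = log_post_fn


def _log_post(theta):
    return _worker_state["log_post"](theta)


def _gaussian_log_post(theta):
    return -0.5 * theta ** 2


if __name__ == "__main__":
    with multiprocessing.Pool(
        2, initializer=_init_worker, initargs=(_gaussian_log_post,)
    ) as pool:
        print(pool.map(_log_post, [0.0, 1.0, 2.0]))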
diff --git a/bilby/core/utils/cmd.py b/bilby/core/utils/cmd.py
index 2305d7fee2ff59fe2ca4ce939e119cc354a0f7cc..eba784bb8dc2b66ced02f08598e5930a5322136d 100644
--- a/bilby/core/utils/cmd.py
+++ b/bilby/core/utils/cmd.py
@@ -111,7 +111,7 @@ def run_commandline(cl, log_level=20, raise_error=True, return_output=True):
             else:
                 out = 0
         os.system('\n')
-        return(out)
+        return out
     else:
         process = subprocess.Popen(cl, shell=True)
         process.communicate()
diff --git a/bilby/core/utils/introspection.py b/bilby/core/utils/introspection.py
index 3d92b5f7e644b6b378075fcf45b9582500db657d..70073f995916f18d41a27bc060818acde2023d0e 100644
--- a/bilby/core/utils/introspection.py
+++ b/bilby/core/utils/introspection.py
@@ -94,11 +94,8 @@ def infer_args_from_function_except_n_args(func, n=1):
         ['c', 'd']
 
     """
-    try:
-        parameters = inspect.getfullargspec(func).args
-    except AttributeError:
-        parameters = inspect.getargspec(func).args
-    del(parameters[:n])
+    parameters = inspect.getfullargspec(func).args
+    del parameters[:n]
     return parameters
 
 
diff --git a/bilby/core/utils/log.py b/bilby/core/utils/log.py
index d9db7581c9bf4a08d50bf4329c30d2b0f628e293..dd9bcfd95b471bbe70f7300758ea317d31051026 100644
--- a/bilby/core/utils/log.py
+++ b/bilby/core/utils/log.py
@@ -1,5 +1,4 @@
 import logging
-import os
 from pathlib import Path
 import sys
 
@@ -60,13 +59,8 @@ def setup_logger(outdir='.', label=None, log_level='INFO', print_version=False):
 
 
 def get_version_information():
-    version_file = os.path.join(
-        os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '.version')
-    try:
-        with open(version_file, 'r') as f:
-            return f.readline().rstrip()
-    except EnvironmentError:
-        print("No version information file '.version' found")
+    from bilby import __version__
+    return __version__
 
 
 def loaded_modules_dict():
diff --git a/bilby/gw/__init__.py b/bilby/gw/__init__.py
index 3242f0ebd92a0b0c3b5c6f14ae1f27c80de2d08e..b5115766b2c421fc6d1585809c44570a3c6bdb9e 100644
--- a/bilby/gw/__init__.py
+++ b/bilby/gw/__init__.py
@@ -1,6 +1,6 @@
 from . import (conversion, cosmology, detector, eos, likelihood, prior,
                result, source, utils, waveform_generator)
-from .waveform_generator import WaveformGenerator
+from .waveform_generator import WaveformGenerator, LALCBCWaveformGenerator
 from .likelihood import GravitationalWaveTransient
 from .detector import calibration
 
diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index d52e05ae8d21f7bff9054da180dcc3fb1ecbeb6d..0a2c7039b3bb0c6fb2eb1976053371103a5effd7 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -210,67 +210,8 @@ def convert_to_lal_binary_black_hole_parameters(parameters):
             converted_parameters[key[:-7]] = converted_parameters[key] * (
                 1 + converted_parameters['redshift'])
 
-    if 'chirp_mass' in converted_parameters.keys():
-        if "mass_1" in converted_parameters.keys():
-            converted_parameters["mass_ratio"] = chirp_mass_and_primary_mass_to_mass_ratio(
-                converted_parameters["chirp_mass"], converted_parameters["mass_1"])
-        if 'total_mass' in converted_parameters.keys():
-            converted_parameters['symmetric_mass_ratio'] =\
-                chirp_mass_and_total_mass_to_symmetric_mass_ratio(
-                    converted_parameters['chirp_mass'],
-                    converted_parameters['total_mass'])
-        if 'symmetric_mass_ratio' in converted_parameters.keys() and "mass_ratio" not in converted_parameters:
-            converted_parameters['mass_ratio'] =\
-                symmetric_mass_ratio_to_mass_ratio(
-                    converted_parameters['symmetric_mass_ratio'])
-        if 'total_mass' not in converted_parameters.keys():
-            converted_parameters['total_mass'] =\
-                chirp_mass_and_mass_ratio_to_total_mass(
-                    converted_parameters['chirp_mass'],
-                    converted_parameters['mass_ratio'])
-        converted_parameters['mass_1'], converted_parameters['mass_2'] = \
-            total_mass_and_mass_ratio_to_component_masses(
-                converted_parameters['mass_ratio'],
-                converted_parameters['total_mass'])
-    elif 'total_mass' in converted_parameters.keys():
-        if 'symmetric_mass_ratio' in converted_parameters.keys():
-            converted_parameters['mass_ratio'] = \
-                symmetric_mass_ratio_to_mass_ratio(
-                    converted_parameters['symmetric_mass_ratio'])
-        if 'mass_ratio' in converted_parameters.keys():
-            converted_parameters['mass_1'], converted_parameters['mass_2'] =\
-                total_mass_and_mass_ratio_to_component_masses(
-                    converted_parameters['mass_ratio'],
-                    converted_parameters['total_mass'])
-        elif 'mass_1' in converted_parameters.keys():
-            converted_parameters['mass_2'] =\
-                converted_parameters['total_mass'] -\
-                converted_parameters['mass_1']
-        elif 'mass_2' in converted_parameters.keys():
-            converted_parameters['mass_1'] = \
-                converted_parameters['total_mass'] - \
-                converted_parameters['mass_2']
-    elif 'symmetric_mass_ratio' in converted_parameters.keys():
-        converted_parameters['mass_ratio'] =\
-            symmetric_mass_ratio_to_mass_ratio(
-                converted_parameters['symmetric_mass_ratio'])
-        if 'mass_1' in converted_parameters.keys():
-            converted_parameters['mass_2'] =\
-                converted_parameters['mass_1'] *\
-                converted_parameters['mass_ratio']
-        elif 'mass_2' in converted_parameters.keys():
-            converted_parameters['mass_1'] =\
-                converted_parameters['mass_2'] /\
-                converted_parameters['mass_ratio']
-    elif 'mass_ratio' in converted_parameters.keys():
-        if 'mass_1' in converted_parameters.keys():
-            converted_parameters['mass_2'] =\
-                converted_parameters['mass_1'] *\
-                converted_parameters['mass_ratio']
-        if 'mass_2' in converted_parameters.keys():
-            converted_parameters['mass_1'] = \
-                converted_parameters['mass_2'] /\
-                converted_parameters['mass_ratio']
+    # we do not require the component masses be added if no mass parameters are present
+    converted_parameters = generate_component_masses(converted_parameters, require_add=False)
 
     for idx in ['1', '2']:
         key = 'chi_{}'.format(idx)
@@ -480,6 +421,33 @@ def total_mass_and_mass_ratio_to_component_masses(mass_ratio, total_mass):
     return mass_1, mass_2
 
 
+def chirp_mass_and_mass_ratio_to_component_masses(chirp_mass, mass_ratio):
+    """
+    Convert chirp mass and mass ratio of a binary to its component masses.
+
+    Parameters
+    ==========
+    chirp_mass: float
+        Chirp mass of the binary
+    mass_ratio: float
+        Mass ratio (mass_2/mass_1) of the binary
+
+    Returns
+    =======
+    mass_1: float
+        Mass of the heavier object
+    mass_2: float
+        Mass of the lighter object
+    """
+    total_mass = chirp_mass_and_mass_ratio_to_total_mass(chirp_mass=chirp_mass,
+                                                         mass_ratio=mass_ratio)
+    mass_1, mass_2 = (
+        total_mass_and_mass_ratio_to_component_masses(
+            total_mass=total_mass, mass_ratio=mass_ratio)
+    )
+    return mass_1, mass_2
+
+
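# --- Editor's check (not part of the diff) ----------------------------------
# The helper above composes M = Mc * (1 + q)**(6/5) / q**(3/5) with the
# component-mass split; a quick numerical self-consistency check. With
# chirp_mass=30 and mass_ratio=0.8 this gives mass_1 ~ 38.6, mass_2 ~ 30.9.
chirp_mass, mass_ratio = 30.0, 0.8
total_mass = chirp_mass * (1 + mass_ratio) ** 1.2 / mass_ratio ** 0.6
mass_1 = total_mass / (1 + mass_ratio)
mass_2 = mass_1 * mass_ratio
assert abs((mass_1 * mass_2) ** 0.6 / (mass_1 + mass_2) ** 0.2 - chirp_mass) < 1e-9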
 def symmetric_mass_ratio_to_mass_ratio(symmetric_mass_ratio):
     """
     Convert the symmetric mass ratio to the normal mass ratio.
@@ -678,6 +646,30 @@ def mass_1_and_chirp_mass_to_mass_ratio(mass_1, chirp_mass):
     return mass_ratio
 
 
+def mass_2_and_chirp_mass_to_mass_ratio(mass_2, chirp_mass):
+    """
+    Calculate mass ratio from mass_2 and chirp_mass.
+
+    This involves solving mc = m2 * (1/q)**(3/5) / (1 + (1/q))**(1/5).
+
+    Parameters
+    ==========
+    mass_2: float
+        Mass of the lighter object
+    chirp_mass: float
+        Chirp mass of the binary
+
+    Returns
+    =======
+    mass_ratio: float
+        Mass ratio of the binary
+    """
+    # Passing mass_2, the expression from the function above
+    # returns 1/q (because chirp mass is invariant under
+    # mass_1 <-> mass_2)
+    return 1 / mass_1_and_chirp_mass_to_mass_ratio(mass_2, chirp_mass)
+
+
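# --- Editor's check (not part of the diff) ----------------------------------
# The symmetry argument above in numbers: the chirp mass built from (m2, 1/q)
# equals the one built from the component masses, so the mass_1 solver fed
# with mass_2 indeed returns 1/q.
m1, m2 = 40.0, 20.0
q = m2 / m1
mc = (m1 * m2) ** 0.6 / (m1 + m2) ** 0.2
assert abs(m2 * (1 / q) ** 0.6 / (1 + 1 / q) ** 0.2 - mc) < 1e-9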
 def lambda_1_lambda_2_to_lambda_tilde(lambda_1, lambda_2, mass_1, mass_2):
     """
     Convert from individual tidal parameters to dominant tidal term.
@@ -838,12 +830,8 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion,
     output_sample = fill_from_fixed_priors(output_sample, priors)
     output_sample, _ = base_conversion(output_sample)
     if likelihood is not None:
-        if (
-                hasattr(likelihood, 'phase_marginalization') or
-                hasattr(likelihood, 'time_marginalization') or
-                hasattr(likelihood, 'distance_marginalization') or
-                hasattr(likelihood, 'calibration_marginalization')
-        ):
+        marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
+        if len(marginalized_parameters) > 0:
             try:
                 generate_posterior_samples_from_marginalized_likelihood(
                     samples=output_sample, likelihood=likelihood, npool=npool)
@@ -854,10 +842,17 @@ def _generate_all_cbc_parameters(sample, defaults, base_conversion,
                     "interpretation.".format(e)
                 )
         if priors is not None:
-            for par, name in zip(
-                    ['distance', 'phase', 'time'],
-                    ['luminosity_distance', 'phase', 'geocent_time']):
-                if getattr(likelihood, '{}_marginalization'.format(par), False):
+            misnamed_marginalizations = dict(
+                distance="luminosity_distance",
+                time="geocent_time",
+                calibration="recalib_index",
+            )
+            for par in marginalized_parameters:
+                name = misnamed_marginalizations.get(par, par)
+                if (
+                    getattr(likelihood, f'{par}_marginalization', False)
+                    and name in likelihood.priors
+                ):
                     priors[name] = likelihood.priors[name]
 
         if (
@@ -1006,33 +1001,220 @@ def fill_from_fixed_priors(sample, priors):
     return output_sample
 
 
-def generate_mass_parameters(sample):
+def generate_component_masses(sample, require_add=False, source=False):
+    """"
+    Add the component masses to the dataframe/dictionary
+    We add:
+        mass_1, mass_2
+    We also add any other masses which may be necessary for
+    intermediate steps, i.e. typically the  total mass is necessary, along
+    with the mass ratio, so these will usually be added to the dictionary
+
+    If `require_add` is True, then having an incomplete set of mass
+    parameters (so that the component mass parameters cannot be added)
+    will throw an error, otherwise it will quietly add nothing to the
+    dictionary.
+
+    Parameters
+    =========
+    sample : dict
+        The input dictionary with at least one
+        component with overall mass scaling (i.e.
+        chirp_mass, mass_1, mass_2, total_mass) and
+        then any other mass parameter.
+    source : bool, default False
+        If True, then perform the conversions for source mass parameters
+        i.e. mass_1_source instead of mass_1
+
+    Returns
+    dict : the updated dictionary
+    """
+    def check_and_return_quietly(require_add, sample):
+        if require_add:
+            raise KeyError("Insufficient mass parameters in input dictionary")
+        else:
+            return sample
+    output_sample = sample.copy()
+
+    if source:
+        mass_1_key = "mass_1_source"
+        mass_2_key = "mass_2_source"
+        total_mass_key = "total_mass_source"
+        chirp_mass_key = "chirp_mass_source"
+    else:
+        mass_1_key = "mass_1"
+        mass_2_key = "mass_2"
+        total_mass_key = "total_mass"
+        chirp_mass_key = "chirp_mass"
+
+    if mass_1_key in sample.keys():
+        if mass_2_key in sample.keys():
+            return output_sample
+        if total_mass_key in sample.keys():
+            output_sample[mass_2_key] = output_sample[total_mass_key] - (
+                output_sample[mass_1_key]
+            )
+            return output_sample
+
+        elif "mass_ratio" in sample.keys():
+            pass
+        elif "symmetric_mass_ratio" in sample.keys():
+            output_sample["mass_ratio"] = (
+                symmetric_mass_ratio_to_mass_ratio(
+                    output_sample["symmetric_mass_ratio"])
+            )
+        elif chirp_mass_key in sample.keys():
+            output_sample["mass_ratio"] = (
+                mass_1_and_chirp_mass_to_mass_ratio(
+                    mass_1=output_sample[mass_1_key],
+                    chirp_mass=output_sample[chirp_mass_key])
+            )
+        else:
+            return check_and_return_quietly(require_add, sample)
+
+        output_sample[mass_2_key] = (
+            output_sample["mass_ratio"] * output_sample[mass_1_key]
+        )
+
+        return output_sample
+
+    elif mass_2_key in sample.keys():
+        # mass_1 is not in the dict
+        if total_mass_key in sample.keys():
+            output_sample[mass_1_key] = (
+                output_sample[total_mass_key] - output_sample[mass_2_key]
+            )
+            return output_sample
+        elif "mass_ratio" in sample.keys():
+            pass
+        elif "symmetric_mass_ratio" in sample.keys():
+            output_sample["mass_ratio"] = (
+                symmetric_mass_ratio_to_mass_ratio(
+                    output_sample["symmetric_mass_ratio"])
+            )
+        elif chirp_mass_key in sample.keys():
+            output_sample["mass_ratio"] = (
+                mass_2_and_chirp_mass_to_mass_ratio(
+                    mass_2=output_sample[mass_2_key],
+                    chirp_mass=output_sample[chirp_mass_key])
+            )
+        else:
+            return check_and_return_quietly(require_add, sample)
+
+        output_sample[mass_1_key] = 1 / output_sample["mass_ratio"] * (
+            output_sample[mass_2_key]
+        )
+
+        return output_sample
+
+    # Only if neither mass_1 nor mass_2 is in the input sample
+    if total_mass_key in sample.keys():
+        if "mass_ratio" in sample.keys():
+            pass  # We have everything we need already
+        elif "symmetric_mass_ratio" in sample.keys():
+            output_sample["mass_ratio"] = (
+                symmetric_mass_ratio_to_mass_ratio(
+                    output_sample["symmetric_mass_ratio"])
+            )
+        elif chirp_mass_key in sample.keys():
+            output_sample["symmetric_mass_ratio"] = (
+                chirp_mass_and_total_mass_to_symmetric_mass_ratio(
+                    chirp_mass=output_sample[chirp_mass_key],
+                    total_mass=output_sample[total_mass_key])
+            )
+            output_sample["mass_ratio"] = (
+                symmetric_mass_ratio_to_mass_ratio(
+                    output_sample["symmetric_mass_ratio"])
+            )
+        else:
+            return check_and_return_quietly(require_add, sample)
+
+    elif chirp_mass_key in sample.keys():
+        if "mass_ratio" in sample.keys():
+            pass
+        elif "symmetric_mass_ratio" in sample.keys():
+            output_sample["mass_ratio"] = (
+                symmetric_mass_ratio_to_mass_ratio(
+                    sample["symmetric_mass_ratio"])
+            )
+        else:
+            return check_and_return_quietly(require_add, sample)
+
+        output_sample[total_mass_key] = (
+            chirp_mass_and_mass_ratio_to_total_mass(
+                chirp_mass=output_sample[chirp_mass_key],
+                mass_ratio=output_sample["mass_ratio"])
+        )
+
+    # If we still lack a total mass or mass ratio, none of the criteria matched
+    if total_mass_key not in output_sample.keys() or (
+            "mass_ratio" not in output_sample.keys()):
+        return check_and_return_quietly(require_add, sample)
+    mass_1, mass_2 = (
+        total_mass_and_mass_ratio_to_component_masses(
+            total_mass=output_sample[total_mass_key],
+            mass_ratio=output_sample["mass_ratio"])
+    )
+    output_sample[mass_1_key] = mass_1
+    output_sample[mass_2_key] = mass_2
+
+    return output_sample
+
+
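# --- Editor's walk-through (not part of the diff) ---------------------------
# How the fall-through above resolves (chirp_mass, symmetric_mass_ratio):
# first mass_ratio, then total_mass, then the component masses. The same
# steps, inlined with the standard formulae (values are illustrative only):
import numpy as np

eta = 0.2469  # symmetric mass ratio, corresponds to q ~ 0.8
q = (1 - 2 * eta - np.sqrt(1 - 4 * eta)) / (2 * eta)
total_mass = 30.0 * (1 + q) ** 1.2 / q ** 0.6  # chirp_mass = 30
mass_1 = total_mass / (1 + q)
mass_2 = total_mass - mass_1
# with require_add=False, an underdetermined input such as
# {"mass_ratio": 0.8} would instead be returned unchanged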
+def generate_mass_parameters(sample, source=False):
     """
-    Add the known mass parameters to the data frame/dictionary.
+    Add the known mass parameters to the data frame/dictionary. We do
+    not recompute keys already present in the dictionary.
 
-    We add:
-        chirp mass, total mass, symmetric mass ratio, mass ratio
+    We add, potentially:
+        chirp mass, total mass, symmetric mass ratio, mass ratio, mass_1, mass_2
 
     Parameters
     ==========
     sample : dict
-        The input dictionary with component masses 'mass_1' and 'mass_2'
-
+        The input dictionary with two independent ("spanning") mass
+        parameters, e.g. (mass_1, mass_2) or (chirp_mass, mass_ratio),
+        but not a degenerate pair such as
+        (mass_ratio, symmetric_mass_ratio).
+
     Returns
     =======
     dict: The updated dictionary
 
     """
-    output_sample = sample.copy()
-    output_sample['chirp_mass'] =\
-        component_masses_to_chirp_mass(sample['mass_1'], sample['mass_2'])
-    output_sample['total_mass'] =\
-        component_masses_to_total_mass(sample['mass_1'], sample['mass_2'])
-    output_sample['symmetric_mass_ratio'] =\
-        component_masses_to_symmetric_mass_ratio(sample['mass_1'],
-                                                 sample['mass_2'])
-    output_sample['mass_ratio'] =\
-        component_masses_to_mass_ratio(sample['mass_1'], sample['mass_2'])
+    # Only add the parameters if they're not already present
+    intermediate_sample = generate_component_masses(sample, source=source)
+    output_sample = intermediate_sample.copy()
+
+    if source:
+        mass_1_key = 'mass_1_source'
+        mass_2_key = 'mass_2_source'
+        total_mass_key = 'total_mass_source'
+        chirp_mass_key = 'chirp_mass_source'
+    else:
+        mass_1_key = 'mass_1'
+        mass_2_key = 'mass_2'
+        total_mass_key = 'total_mass'
+        chirp_mass_key = 'chirp_mass'
+
+    if chirp_mass_key not in output_sample.keys():
+        output_sample[chirp_mass_key] = (
+            component_masses_to_chirp_mass(output_sample[mass_1_key],
+                                           output_sample[mass_2_key])
+        )
+    if total_mass_key not in output_sample.keys():
+        output_sample[total_mass_key] = (
+            component_masses_to_total_mass(output_sample[mass_1_key],
+                                           output_sample[mass_2_key])
+        )
+    if 'symmetric_mass_ratio' not in output_sample.keys():
+        output_sample['symmetric_mass_ratio'] = (
+            component_masses_to_symmetric_mass_ratio(output_sample[mass_1_key],
+                                                     output_sample[mass_2_key])
+        )
+    if 'mass_ratio' not in output_sample.keys():
+        output_sample['mass_ratio'] = (
+            component_masses_to_mass_ratio(output_sample[mass_1_key],
+                                           output_sample[mass_2_key])
+        )
 
     return output_sample
 
@@ -1227,9 +1409,14 @@ def compute_snrs(sample, likelihood, npool=1):
             from tqdm.auto import tqdm
             logger.info('Computing SNRs for every sample.')
 
-            fill_args = [(ii, row, likelihood) for ii, row in sample.iterrows()]
+            fill_args = [(ii, row) for ii, row in sample.iterrows()]
             if npool > 1:
-                pool = multiprocessing.Pool(processes=npool)
+                from ..core.sampler.base_sampler import _initialize_global_variables
+                pool = multiprocessing.Pool(
+                    processes=npool,
+                    initializer=_initialize_global_variables,
+                    initargs=(likelihood, None, None, False),
+                )
                 logger.info(
                     "Using a pool with size {} for nsamples={}".format(npool, len(sample))
                 )
@@ -1237,6 +1424,8 @@ def compute_snrs(sample, likelihood, npool=1):
                 pool.close()
                 pool.join()
             else:
+                from ..core.sampler.base_sampler import _sampling_convenience_dump
+                _sampling_convenience_dump.likelihood = likelihood
                 new_samples = [_compute_snrs(xx) for xx in tqdm(fill_args, file=sys.stdout)]
 
             for ii, ifo in enumerate(likelihood.interferometers):
@@ -1257,7 +1446,9 @@ def compute_snrs(sample, likelihood, npool=1):
 
 def _compute_snrs(args):
     """A wrapper of computing the SNRs to enable multiprocessing"""
-    ii, sample, likelihood = args
+    from ..core.sampler.base_sampler import _sampling_convenience_dump
+    likelihood = _sampling_convenience_dump.likelihood
+    ii, sample = args
     sample = dict(sample).copy()
     likelihood.parameters.update(sample)
     signal_polarizations = likelihood.waveform_generator.frequency_domain_strain(
@@ -1296,10 +1487,8 @@ def generate_posterior_samples_from_marginalized_likelihood(
     sample: DataFrame
         Returns the posterior with new samples.
     """
-    if not any([likelihood.phase_marginalization,
-                likelihood.distance_marginalization,
-                likelihood.time_marginalization,
-                likelihood.calibration_marginalization]):
+    marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
+    if len(marginalized_parameters) == 0:
         return samples
 
     # pass through a dictionary
@@ -1318,11 +1507,15 @@ def generate_posterior_samples_from_marginalized_likelihood(
         use_cache = False
 
     if use_cache and os.path.exists(cache_filename) and not command_line_args.clean:
-        with open(cache_filename, "rb") as f:
-            cached_samples_dict = pickle.load(f)
+        try:
+            with open(cache_filename, "rb") as f:
+                cached_samples_dict = pickle.load(f)
+        except EOFError:
+            logger.warning("Cache file is empty")
+            cached_samples_dict = None
 
         # Check the samples are identical between the cache and current
-        if cached_samples_dict["_samples"].equals(samples):
+        if (cached_samples_dict is not None) and (cached_samples_dict["_samples"].equals(samples)):
             # Calculate reconstruction percentage and print a log message
             nsamples_converted = np.sum(
                 [len(val) for key, val in cached_samples_dict.items() if key != "_samples"]
@@ -1342,15 +1535,22 @@ def generate_posterior_samples_from_marginalized_likelihood(
 
     # Set up the multiprocessing
     if npool > 1:
-        pool = multiprocessing.Pool(processes=npool)
+        from ..core.sampler.base_sampler import _initialize_global_variables
+        pool = multiprocessing.Pool(
+            processes=npool,
+            initializer=_initialize_global_variables,
+            initargs=(likelihood, None, None, False),
+        )
         logger.info(
             "Using a pool with size {} for nsamples={}"
             .format(npool, len(samples))
         )
     else:
+        from ..core.sampler.base_sampler import _sampling_convenience_dump
+        _sampling_convenience_dump.likelihood = likelihood
         pool = None
 
-    fill_args = [(ii, row, likelihood) for ii, row in samples.iterrows()]
+    fill_args = [(ii, row) for ii, row in samples.iterrows()]
     ii = 0
     pbar = tqdm(total=len(samples), file=sys.stdout)
     while ii < len(samples):
@@ -1382,11 +1582,8 @@ def generate_posterior_samples_from_marginalized_likelihood(
         [np.array(val) for key, val in cached_samples_dict.items() if key != "_samples"]
     )
 
-    samples['geocent_time'] = new_samples[:, 0]
-    samples['luminosity_distance'] = new_samples[:, 1]
-    samples['phase'] = new_samples[:, 2]
-    if likelihood.calibration_marginalization:
-        samples['recalib_index'] = new_samples[:, 3]
+    for ii, key in enumerate(marginalized_parameters):
+        samples[key] = new_samples[:, ii]
 
     return samples
 
@@ -1412,14 +1609,11 @@ def generate_sky_frame_parameters(samples, likelihood):
 
 
 def fill_sample(args):
-    ii, sample, likelihood = args
+    from ..core.sampler.base_sampler import _sampling_convenience_dump
+    likelihood = _sampling_convenience_dump.likelihood
+    ii, sample = args
+    marginalized_parameters = getattr(likelihood, "_marginalized_parameters", list())
     sample = dict(sample).copy()
     likelihood.parameters.update(dict(sample).copy())
     new_sample = likelihood.generate_posterior_sample_from_marginalized_likelihood()
-
-    if not likelihood.calibration_marginalization:
-        return new_sample["geocent_time"], new_sample["luminosity_distance"],\
-            new_sample["phase"]
-    else:
-        return new_sample["geocent_time"], new_sample["luminosity_distance"],\
-            new_sample["phase"], new_sample['recalib_index']
+    return tuple((new_sample[key] for key in marginalized_parameters))
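# --- Editor's note (not part of the diff) -----------------------------------
# Reconstructed marginalized parameters are now written back by iterating over
# the likelihood's _marginalized_parameters list instead of hard-coded column
# names. A toy illustration of that reshaping step (values are made up):
import numpy as np
import pandas as pd

marginalized_parameters = ["geocent_time", "luminosity_distance", "phase"]
new_samples = np.array([[0.1, 400.0, 1.0], [0.2, 410.0, 2.0]])
samples = pd.DataFrame(index=range(len(new_samples)))
for ii, key in enumerate(marginalized_parameters):
    samples[key] = new_samples[:, ii]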
diff --git a/bilby/gw/detector/geometry.py b/bilby/gw/detector/geometry.py
index 5ed8ec2e789c86396e32f4a12ae1f1f19412e41c..d7e1433decbccaef03990c8dc9f1fa99a2350c2e 100644
--- a/bilby/gw/detector/geometry.py
+++ b/bilby/gw/detector/geometry.py
@@ -1,4 +1,5 @@
 import numpy as np
+from bilby_cython.geometry import calculate_arm, detector_tensor
 
 from .. import utils as gwutils
 
@@ -263,7 +264,7 @@ class InterferometerGeometry(object):
         if not self._x_updated or not self._y_updated:
             _, _ = self.x, self.y  # noqa
         if not self._detector_tensor_updated:
-            self._detector_tensor = 0.5 * (np.einsum('i,j->ij', self.x, self.x) - np.einsum('i,j->ij', self.y, self.y))
+            self._detector_tensor = detector_tensor(x=self.x, y=self.y)
             self._detector_tensor_updated = True
         return self._detector_tensor
 
@@ -288,19 +289,18 @@ class InterferometerGeometry(object):
 
         """
         if arm == 'x':
-            return self._calculate_arm(self._xarm_tilt, self._xarm_azimuth)
+            return calculate_arm(
+                arm_tilt=self._xarm_tilt,
+                arm_azimuth=self._xarm_azimuth,
+                longitude=self._longitude,
+                latitude=self._latitude
+            )
         elif arm == 'y':
-            return self._calculate_arm(self._yarm_tilt, self._yarm_azimuth)
+            return calculate_arm(
+                arm_tilt=self._yarm_tilt,
+                arm_azimuth=self._yarm_azimuth,
+                longitude=self._longitude,
+                latitude=self._latitude
+            )
         else:
             raise ValueError("Arm must either be 'x' or 'y'.")
-
-    def _calculate_arm(self, arm_tilt, arm_azimuth):
-        e_long = np.array([-np.sin(self._longitude), np.cos(self._longitude), 0])
-        e_lat = np.array([-np.sin(self._latitude) * np.cos(self._longitude),
-                          -np.sin(self._latitude) * np.sin(self._longitude), np.cos(self._latitude)])
-        e_h = np.array([np.cos(self._latitude) * np.cos(self._longitude),
-                        np.cos(self._latitude) * np.sin(self._longitude), np.sin(self._latitude)])
-
-        return (np.cos(arm_tilt) * np.cos(arm_azimuth) * e_long +
-                np.cos(arm_tilt) * np.sin(arm_azimuth) * e_lat +
-                np.sin(arm_tilt) * e_h)
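# --- Editor's reference (not part of the diff) ------------------------------
# The cython detector_tensor call replaces the einsum expression deleted
# above; the quantity is unchanged, D = (x xT - y yT) / 2 for unit arm
# vectors x and y:
import numpy as np


def detector_tensor_reference(x, y):
    return 0.5 * (np.einsum("i,j->ij", x, x) - np.einsum("i,j->ij", y, y))


x, y = np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])
print(detector_tensor_reference(x, y))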
diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py
index 06b298386e3d2d8a8887f89af3d65f3e58919efd..08290f2e578a97e7afcfbd49b32120425f52b48c 100644
--- a/bilby/gw/detector/interferometer.py
+++ b/bilby/gw/detector/interferometer.py
@@ -1,6 +1,11 @@
 import os
 
 import numpy as np
+from bilby_cython.geometry import (
+    get_polarization_tensor,
+    three_by_three_matrix_contraction,
+    time_delay_from_geocenter,
+)
 
 from ...core import utils
 from ...core.utils import docstring, logger, PropertyAccessor
@@ -268,11 +273,11 @@ class Interferometer(object):
 
         Returns
         =======
-        array_like: A 3x3 array representation of the antenna response for the specified mode
+        float: The antenna response for the specified mode and time/location
 
         """
-        polarization_tensor = gwutils.get_polarization_tensor(ra, dec, time, psi, mode)
-        return np.einsum('ij,ij->', self.geometry.detector_tensor, polarization_tensor)
+        polarization_tensor = get_polarization_tensor(ra, dec, time, psi, mode)
+        return three_by_three_matrix_contraction(self.geometry.detector_tensor, polarization_tensor)
 
     def get_detector_response(self, waveform_polarizations, parameters):
         """ Get the detector response for a particular waveform
@@ -527,7 +532,7 @@ class Interferometer(object):
         =======
         float: The time delay from geocenter in seconds
         """
-        return gwutils.time_delay_geocentric(self.geometry.vertex, np.array([0, 0, 0]), ra, dec, time)
+        return time_delay_from_geocenter(self.geometry.vertex, ra, dec, time)
 
     def vertex_position_geocentric(self):
         """
diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index d4d8eca345bb8f6dfca921f4cd20bfcbc3c5cb6c..6ed82c7ba6e9789a2a6f0b922d0ee0b12b5cecd7 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -448,8 +448,7 @@ class GravitationalWaveTransient(Likelihood):
         This involves a deepcopy of the signal to avoid issues with waveform
         caching, as the signal is overwritten in place.
         """
-        if any([self.phase_marginalization, self.distance_marginalization,
-                self.time_marginalization, self.calibration_marginalization]):
+        if len(self._marginalized_parameters) > 0:
             signal_polarizations = copy.deepcopy(
                 self.waveform_generator.frequency_domain_strain(
                     self.parameters))
@@ -898,8 +897,9 @@ class GravitationalWaveTransient(Likelihood):
         for key in pairs:
             if key not in loaded_file:
                 return False, key
-            elif not np.array_equal(np.atleast_1d(loaded_file[key]),
-                                    np.atleast_1d(pairs[key])):
+            elif not np.allclose(np.atleast_1d(loaded_file[key]),
+                                 np.atleast_1d(pairs[key]),
+                                 rtol=1e-15):
                 return False, key
         return True, None
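
Replacing :code:`np.array_equal` with :code:`np.allclose` and a tight
:code:`rtol` makes the cached-file check robust against harmless floating-point
round-off; a small illustration (the default :code:`atol` also contributes)::

    import numpy as np

    a = np.array([20.0, 2048.0])  # e.g. minimum frequency and sampling rate
    b = a * (1 + 1e-15)           # perturbed in the last floating-point digits
    print(np.array_equal(a, b))           # False: exact, bitwise comparison
    print(np.allclose(a, b, rtol=1e-15))  # True: equal to within round-off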
 
diff --git a/bilby/gw/likelihood/multiband.py b/bilby/gw/likelihood/multiband.py
index f361742c123af9cd5f27d236b76d1f3e1a4efc7b..9d0b7e1d154774a95469154918ab55601bf626a3 100644
--- a/bilby/gw/likelihood/multiband.py
+++ b/bilby/gw/likelihood/multiband.py
@@ -570,7 +570,7 @@ class MBGravitationalWaveTransient(GravitationalWaveTransient):
             self.banded_frequency_points, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
 
         strain *= np.exp(-1j * 2. * np.pi * self.banded_frequency_points * ifo_time)
-        strain *= np.conjugate(calib_factor)
+        strain *= calib_factor
 
         d_inner_h = np.dot(strain, self.linear_coeffs[interferometer.name])
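
Dropping the :code:`np.conjugate` matters because conjugating a complex
calibration factor flips the sign of its phase; a toy illustration with
hypothetical numbers::

    import numpy as np

    calib_factor = 1.05 * np.exp(1j * 0.1)  # 5% amplitude, 0.1 rad phase error
    template = 1.0 + 0.0j
    print(np.angle(template * calib_factor))                # +0.1, as intended
    print(np.angle(template * np.conjugate(calib_factor)))  # -0.1, sign flipped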
 
diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py
index 42cfd83140b040ec5867e9088a391079cf7a9330..f7ba3b1db5da686e693dfcde267d7ae5c6e418fe 100644
--- a/bilby/gw/likelihood/roq.py
+++ b/bilby/gw/likelihood/roq.py
@@ -371,8 +371,10 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
         else:
             time_ref = self.parameters['geocent_time']
 
-        size_linear = len(self.waveform_generator.waveform_arguments['frequency_nodes_linear'])
-        size_quadratic = len(self.waveform_generator.waveform_arguments['frequency_nodes_quadratic'])
+        frequency_nodes_linear = self.waveform_generator.waveform_arguments['frequency_nodes_linear']
+        frequency_nodes_quadratic = self.waveform_generator.waveform_arguments['frequency_nodes_quadratic']
+        size_linear = len(frequency_nodes_linear)
+        size_quadratic = len(frequency_nodes_quadratic)
         h_linear = np.zeros(size_linear, dtype=complex)
         h_quadratic = np.zeros(size_quadratic, dtype=complex)
         for mode in waveform_polarizations['linear']:
@@ -385,9 +387,9 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
             h_quadratic += waveform_polarizations['quadratic'][mode] * response
 
         calib_linear = interferometer.calibration_model.get_calibration_factor(
-            size_linear, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
+            frequency_nodes_linear, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
         calib_quadratic = interferometer.calibration_model.get_calibration_factor(
-            size_quadratic, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
+            frequency_nodes_quadratic, prefix='recalib_{}_'.format(interferometer.name), **self.parameters)
 
         h_linear *= calib_linear
         h_quadratic *= calib_quadratic
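
The fix passes the actual frequency nodes, rather than their lengths, which is
what the calibration model expects; a hedged usage sketch with bilby's spline
model, assuming the :code:`CubicSpline` constructor signature below and using
illustrative parameter values::

    import numpy as np
    from bilby.gw.detector.calibration import CubicSpline

    model = CubicSpline(
        prefix="recalib_H1_",
        minimum_frequency=20,
        maximum_frequency=1024,
        n_points=5,
    )
    frequency_nodes = np.geomspace(20, 1024, 100)
    parameters = {f"recalib_H1_amplitude_{ii}": 0.0 for ii in range(5)}
    parameters.update({f"recalib_H1_phase_{ii}": 0.0 for ii in range(5)})
    # with all node parameters zero, the calibration factor should be unity
    calib_factor = model.get_calibration_factor(frequency_nodes, **parameters)
    assert np.allclose(calib_factor, 1)
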
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 54f803401cac2f64eeaaca46d307eba68d7efc7d..1f605f56f1a85509b09f0fc55e75e491aa7f3a63 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -1183,7 +1183,7 @@ class CalibrationPriorDict(PriorDict):
     @staticmethod
     def constant_uncertainty_spline(
             amplitude_sigma, phase_sigma, minimum_frequency, maximum_frequency,
-            n_nodes, label):
+            n_nodes, label, boundary="reflective"):
         """
         Make prior assuming constant in frequency calibration uncertainty.
 
@@ -1203,6 +1203,8 @@ class CalibrationPriorDict(PriorDict):
             Number of nodes for the spline.
         label: str
             Label for the names of the parameters, e.g., `recalib_H1_`
+        boundary: None, 'reflective', 'periodic'
+            The type of prior boundary to assign
 
         Returns
         =======
@@ -1225,14 +1227,14 @@ class CalibrationPriorDict(PriorDict):
             prior[name] = Gaussian(mu=amplitude_mean_nodes[ii],
                                    sigma=amplitude_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_phase_{}".format(label, ii)
             latex_label = r"$\phi^{}_{}$".format(label, ii)
             prior[name] = Gaussian(mu=phase_mean_nodes[ii],
                                    sigma=phase_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_frequency_{}".format(label, ii)
             latex_label = "$f^{}_{}$".format(label, ii)
diff --git a/bilby/gw/prior_files/GW150914.prior b/bilby/gw/prior_files/GW150914.prior
index 2a48c390735e58f9c11ece2429842796aac7d23c..b97de3fa4e5f803f5b8eb0bfb0baeec4dda20575 100644
--- a/bilby/gw/prior_files/GW150914.prior
+++ b/bilby/gw/prior_files/GW150914.prior
@@ -1,5 +1,5 @@
 mass_ratio = bilby.gw.prior.UniformInComponentsMassRatio(name='mass_ratio', minimum=0.125, maximum=1)
-chirp_mass = bilby.gw.prior.UniformInComponentsChirpMass(name='chirp_mass', minimum=25, maximum=31)
+chirp_mass = bilby.gw.prior.UniformInComponentsChirpMass(name='chirp_mass', minimum=25, maximum=35)
 mass_1 = Constraint(name='mass_1', minimum=10, maximum=80)
 mass_2 = Constraint(name='mass_2', minimum=10, maximum=80)
 a_1 = Uniform(name='a_1', minimum=0, maximum=0.99)
diff --git a/bilby/gw/source.py b/bilby/gw/source.py
index 5ea9415e98fb0bb83bc8e3fb214c2d4654d8c549..4d7793b92fe8124e6e0c3f4397df1bfacc3ca91d 100644
--- a/bilby/gw/source.py
+++ b/bilby/gw/source.py
@@ -882,11 +882,10 @@ def sinegaussian(frequency_array, hrss, Q, frequency, **kwargs):
                (np.exp(-fm**2 * np.pi**2 * tau**2) -
                np.exp(-fp**2 * np.pi**2 * tau**2)))
 
-    return{'plus': h_plus, 'cross': h_cross}
+    return {'plus': h_plus, 'cross': h_cross}
 
 
-def supernova(
-        frequency_array, realPCs, imagPCs, file_path, luminosity_distance, **kwargs):
+def supernova(frequency_array, luminosity_distance, **kwargs):
     """
     A source model that reads a simulation from a text file.
 
@@ -897,8 +896,6 @@ def supernova(
     ----------
     frequency_array: array-like
         Unused
-    realPCs: UNUSED
-    imagPCs: UNUSED
     file_path: str
         Path to the file containing the NR simulation. The format of this file
-        should be readable by :code:`numpy.loadtxt` and have four columns
+        should be readable by :code:`numpy.genfromtxt` and have four columns
@@ -907,7 +904,8 @@ def supernova(
     luminosity_distance: float
         The distance to the source in kpc; this scales the amplitude of the
         signal. The simulation is assumed to be at 10 kpc.
-    kwargs: UNUSED
+    kwargs:
+        Extra keyword arguments; these must include the :code:`file_path`
 
     Returns
     -------
@@ -915,20 +913,20 @@ def supernova(
         A dictionary containing the plus and cross components of the signal.
     """
 
-    realhplus, imaghplus, realhcross, imaghcross = np.loadtxt(
-        file_path, usecols=(0, 1, 2, 3), unpack=True)
+    file_path = kwargs["file_path"]
+    data = np.genfromtxt(file_path)
 
     # waveform in file at 10kpc
     scaling = 1e-3 * (10.0 / luminosity_distance)
 
-    h_plus = scaling * (realhplus + 1.0j * imaghplus)
-    h_cross = scaling * (realhcross + 1.0j * imaghcross)
+    h_plus = scaling * (data[:, 0] + 1j * data[:, 1])
+    h_cross = scaling * (data[:, 2] + 1j * data[:, 3])
     return {'plus': h_plus, 'cross': h_cross}
 
 
 def supernova_pca_model(
-        frequency_array, pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5,
-        luminosity_distance, **kwargs):
+        frequency_array, pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5, luminosity_distance, **kwargs
+):
     r"""
     Signal model based on a five-component principal component decomposition
     of a model.
@@ -966,22 +964,19 @@ def supernova_pca_model(
         The plus and cross polarizations of the signal
     """
 
-    realPCs = kwargs['realPCs']
-    imagPCs = kwargs['imagPCs']
+    principal_components = kwargs["realPCs"] + 1j * kwargs["imagPCs"]
+    coefficients = [pc_coeff1, pc_coeff2, pc_coeff3, pc_coeff4, pc_coeff5]
 
-    pc1 = realPCs[:, 0] + 1.0j * imagPCs[:, 0]
-    pc2 = realPCs[:, 1] + 1.0j * imagPCs[:, 1]
-    pc3 = realPCs[:, 2] + 1.0j * imagPCs[:, 2]
-    pc4 = realPCs[:, 3] + 1.0j * imagPCs[:, 3]
-    pc5 = realPCs[:, 4] + 1.0j * imagPCs[:, 5]
+    strain = np.sum(
+        [coeff * principal_components[:, ii] for ii, coeff in enumerate(coefficients)],
+        axis=0
+    )
 
     # file at 10kpc
     scaling = 1e-23 * (10.0 / luminosity_distance)
 
-    h_plus = scaling * (pc_coeff1 * pc1 + pc_coeff2 * pc2 + pc_coeff3 * pc3 +
-                        pc_coeff4 * pc4 + pc_coeff5 * pc5)
-    h_cross = scaling * (pc_coeff1 * pc1 + pc_coeff2 * pc2 + pc_coeff3 * pc3 +
-                         pc_coeff4 * pc4 + pc_coeff5 * pc5)
+    h_plus = scaling * strain
+    h_cross = scaling * strain
 
     return {'plus': h_plus, 'cross': h_cross}
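
A minimal sketch calling the refactored :code:`supernova_pca_model` with random
stand-in principal components (the shapes and coefficient values are
placeholders)::

    import numpy as np
    from bilby.gw.source import supernova_pca_model

    rng = np.random.default_rng(0)
    frequency_array = np.linspace(0, 1024, 513)
    pcs = rng.normal(size=(len(frequency_array), 5))  # five principal components
    polarizations = supernova_pca_model(
        frequency_array,
        pc_coeff1=0.1,
        pc_coeff2=0.2,
        pc_coeff3=0.0,
        pc_coeff4=0.0,
        pc_coeff5=0.0,
        luminosity_distance=10,  # kpc, the reference distance of the basis
        realPCs=pcs,
        imagPCs=pcs,
    )
    assert polarizations["plus"].shape == frequency_array.shape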
 
diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py
index 34211e82b9f31e0a1c9bcb8e2d1ff28c39b269bb..da9956d2c9b9e1d8e24163e1d183077bd61c8d48 100644
--- a/bilby/gw/utils.py
+++ b/bilby/gw/utils.py
@@ -1,14 +1,16 @@
 import json
 import os
 from functools import lru_cache
-from math import fmod
 
 import numpy as np
 from scipy.interpolate import interp1d
 from scipy.special import i0e
+from bilby_cython.geometry import (
+    zenith_azimuth_to_theta_phi as _zenith_azimuth_to_theta_phi,
+)
+from bilby_cython.time import greenwich_mean_sidereal_time
 
-from ..core.utils import (ra_dec_to_theta_phi,
-                          speed_of_light, logger, run_commandline,
+from ..core.utils import (logger, run_commandline,
                           check_directory_exists_and_if_not_mkdir,
                           SamplesSummary, theta_phi_to_ra_dec)
 from ..core.utils.constants import solar_mass
@@ -53,90 +55,6 @@ def psd_from_freq_series(freq_data, df):
     return np.power(asd_from_freq_series(freq_data, df), 2)
 
 
-def time_delay_geocentric(detector1, detector2, ra, dec, time):
-    """
-    Calculate time delay between two detectors in geocentric coordinates based on XLALArrivaTimeDiff in TimeDelay.c
-
-    Parameters
-    ==========
-    detector1: array_like
-        Cartesian coordinate vector for the first detector in the geocentric frame
-        generated by the Interferometer class as self.vertex.
-    detector2: array_like
-        Cartesian coordinate vector for the second detector in the geocentric frame.
-        To get time delay from Earth center, use detector2 = np.array([0,0,0])
-    ra: float
-        Right ascension of the source in radians
-    dec: float
-        Declination of the source in radians
-    time: float
-        GPS time in the geocentric frame
-
-    Returns
-    =======
-    float: Time delay between the two detectors in the geocentric frame
-
-    """
-    gmst = fmod(greenwich_mean_sidereal_time(time), 2 * np.pi)
-    theta, phi = ra_dec_to_theta_phi(ra, dec, gmst)
-    omega = np.array([np.sin(theta) * np.cos(phi), np.sin(theta) * np.sin(phi), np.cos(theta)])
-    delta_d = detector2 - detector1
-    return np.dot(omega, delta_d) / speed_of_light
-
-
-def get_polarization_tensor(ra, dec, time, psi, mode):
-    """
-    Calculate the polarization tensor for a given sky location and time
-
-    See Nishizawa et al. (2009) arXiv:0903.0528 for definitions of the polarisation tensors.
-    [u, v, w] represent the Earth-frame
-    [m, n, omega] represent the wave-frame
-    Note: there is a typo in the definition of the wave-frame in Nishizawa et al.
-
-    Parameters
-    ==========
-    ra: float
-        right ascension in radians
-    dec: float
-        declination in radians
-    time: float
-        geocentric GPS time
-    psi: float
-        binary polarisation angle counter-clockwise about the direction of propagation
-    mode: str
-        polarisation mode
-
-    Returns
-    =======
-    array_like: A 3x3 representation of the polarization_tensor for the specified mode.
-
-    """
-    gmst = fmod(greenwich_mean_sidereal_time(time), 2 * np.pi)
-    theta, phi = ra_dec_to_theta_phi(ra, dec, gmst)
-    u = np.array([np.cos(phi) * np.cos(theta), np.cos(theta) * np.sin(phi), -np.sin(theta)])
-    v = np.array([-np.sin(phi), np.cos(phi), 0])
-    m = -u * np.sin(psi) - v * np.cos(psi)
-    n = -u * np.cos(psi) + v * np.sin(psi)
-
-    if mode.lower() == 'plus':
-        return np.einsum('i,j->ij', m, m) - np.einsum('i,j->ij', n, n)
-    elif mode.lower() == 'cross':
-        return np.einsum('i,j->ij', m, n) + np.einsum('i,j->ij', n, m)
-    elif mode.lower() == 'breathing':
-        return np.einsum('i,j->ij', m, m) + np.einsum('i,j->ij', n, n)
-
-    # Calculating omega here to avoid calculation when model in [plus, cross, breathing]
-    omega = np.cross(m, n)
-    if mode.lower() == 'longitudinal':
-        return np.einsum('i,j->ij', omega, omega)
-    elif mode.lower() == 'x':
-        return np.einsum('i,j->ij', m, omega) + np.einsum('i,j->ij', omega, m)
-    elif mode.lower() == 'y':
-        return np.einsum('i,j->ij', n, omega) + np.einsum('i,j->ij', omega, n)
-    else:
-        raise ValueError("{} not a polarization mode!".format(mode))
-
-
 def get_vertex_position_geocentric(latitude, longitude, elevation):
     """
     Calculate the position of the IFO vertex in geocentric coordinates in meters.
@@ -310,56 +228,6 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non
     return sum(integral).real
 
 
-__cached_euler_matrix = None
-__cached_delta_x = None
-
-
-def euler_rotation(delta_x):
-    """
-    Calculate the rotation matrix mapping the vector (0, 0, 1) to delta_x
-    while preserving the origin of the azimuthal angle.
-
-    This is decomposed into three Euler angle, alpha, beta, gamma, which rotate
-    about the z-, y-, and z- axes respectively.
-
-    Parameters
-    ==========
-    delta_x: array-like (3,)
-        Vector onto which (0, 0, 1) should be mapped.
-
-    Returns
-    =======
-    total_rotation: array-like (3,3)
-        Rotation matrix which maps vectors from the frame in which delta_x is
-        aligned with the z-axis to the target frame.
-    """
-    global __cached_delta_x
-    global __cached_euler_matrix
-
-    delta_x = delta_x / np.sum(delta_x**2)**0.5
-    if np.array_equal(delta_x, __cached_delta_x):
-        return __cached_euler_matrix
-    else:
-        __cached_delta_x = delta_x
-    alpha = np.arctan(- delta_x[1] * delta_x[2] / delta_x[0])
-    beta = np.arccos(delta_x[2])
-    gamma = np.arctan(delta_x[1] / delta_x[0])
-    rotation_1 = np.array([
-        [np.cos(alpha), -np.sin(alpha), 0], [np.sin(alpha), np.cos(alpha), 0],
-        [0, 0, 1]])
-    rotation_2 = np.array([
-        [np.cos(beta), 0, - np.sin(beta)], [0, 1, 0],
-        [np.sin(beta), 0, np.cos(beta)]])
-    rotation_3 = np.array([
-        [np.cos(gamma), -np.sin(gamma), 0], [np.sin(gamma), np.cos(gamma), 0],
-        [0, 0, 1]])
-    total_rotation = np.einsum(
-        'ij,jk,kl->il', rotation_3, rotation_2, rotation_1)
-    __cached_delta_x = delta_x
-    __cached_euler_matrix = total_rotation
-    return total_rotation
-
-
 def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos):
     """
     Convert from the 'detector frame' to the Earth frame.
@@ -379,15 +247,7 @@ def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos):
         The zenith and azimuthal angles in the earth frame.
     """
     delta_x = ifos[0].geometry.vertex - ifos[1].geometry.vertex
-    omega_prime = np.array([
-        np.sin(zenith) * np.cos(azimuth),
-        np.sin(zenith) * np.sin(azimuth),
-        np.cos(zenith)])
-    rotation_matrix = euler_rotation(delta_x)
-    omega = np.dot(rotation_matrix, omega_prime)
-    theta = np.arccos(omega[2])
-    phi = np.arctan2(omega[1], omega[0]) % (2 * np.pi)
-    return theta, phi
+    return _zenith_azimuth_to_theta_phi(zenith, azimuth, delta_x)
 
 
 def zenith_azimuth_to_ra_dec(zenith, azimuth, geocent_time, ifos):
@@ -1005,27 +865,6 @@ def plot_spline_pos(log_freqs, samples, nfreqs=100, level=0.9, color='k', label=
     plt.xlim(freq_points.min() - .5, freq_points.max() + 50)
 
 
-def greenwich_mean_sidereal_time(time):
-    """
-    Compute the greenwich mean sidereal time from a GPS time.
-
-    This is just a wrapper around :code:`lal.GreenwichMeanSiderealTime` .
-
-    Parameters
-    ----------
-    time: float
-        The GPS time to convert.
-
-    Returns
-    -------
-    float
-        The sidereal time.
-    """
-    from lal import GreenwichMeanSiderealTime
-    time = float(time)
-    return GreenwichMeanSiderealTime(time)
-
-
 def ln_i0(value):
     """
     A numerically stable method to evaluate ln(I_0), a modified Bessel function
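
The :code:`ln_i0` helper stays finite where a naive :code:`np.log(i0(x))`
overflows, by using the exponentially scaled Bessel function and the identity
:math:`I_0(x) = \mathrm{i0e}(x) e^{x}`; a quick demonstration::

    import numpy as np
    from scipy.special import i0, i0e

    x = 800.0
    print(np.log(i0(x)))       # inf: i0(800) overflows float64
    print(np.log(i0e(x)) + x)  # ~795.74: the numerically stable form
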
diff --git a/bilby/gw/waveform_generator.py b/bilby/gw/waveform_generator.py
index f4cd5a6feeea9cb0804dbf1750d33cfe65addb12..f85e42a964c2f1e079eb80e7b112911bbfa041c3 100644
--- a/bilby/gw/waveform_generator.py
+++ b/bilby/gw/waveform_generator.py
@@ -4,6 +4,7 @@ from ..core import utils
 from ..core.series import CoupledTimeAndFrequencySeries
 from ..core.utils import PropertyAccessor
 from .conversion import convert_to_lal_binary_black_hole_parameters
+from .utils import lalsim_GetApproximantFromString
 
 
 class WaveformGenerator(object):
@@ -253,3 +254,20 @@ class WaveformGenerator(object):
             raise AttributeError('Either time or frequency domain source '
                                  'model must be provided.')
         return set(utils.infer_parameters_from_function(model))
+
+
+class LALCBCWaveformGenerator(WaveformGenerator):
+    """ A waveform generator with specific checks for LAL CBC waveforms """
+    LAL_SIM_INSPIRAL_SPINS_FLOW = 1
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.validate_reference_frequency()
+
+    def validate_reference_frequency(self):
+        from lalsimulation import SimInspiralGetSpinFreqFromApproximant
+        waveform_approximant = self.waveform_arguments["waveform_approximant"]
+        waveform_approximant_number = lalsim_GetApproximantFromString(waveform_approximant)
+        if SimInspiralGetSpinFreqFromApproximant(waveform_approximant_number) == self.LAL_SIM_INSPIRAL_SPINS_FLOW:
+            if self.waveform_arguments["reference_frequency"] != self.waveform_arguments["minimum_frequency"]:
+                raise ValueError(f"For {waveform_approximant}, reference_frequency must equal minimum_frequency")
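
A hedged usage sketch of the new validating generator; for approximants whose
spins LALSimulation defines at the reference frequency (IMRPhenomPv2, to our
understanding) construction succeeds, while approximants with spins defined at
the low-frequency cutoff raise unless the two frequencies agree::

    import bilby

    generator = bilby.gw.waveform_generator.LALCBCWaveformGenerator(
        duration=4,
        sampling_frequency=2048,
        frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
        waveform_arguments=dict(
            waveform_approximant="IMRPhenomPv2",
            reference_frequency=50.0,
            minimum_frequency=20.0,
        ),
    )
    # for a spins-at-f_low approximant the same call raises
    # ValueError: ... reference_frequency must equal minimum_frequency
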
diff --git a/containers/dockerfile-template b/containers/dockerfile-template
index 0d90276e1fc878ad1d4d6b6669f10ba3c5181801..1f5532cda5ad1e6238690c9b1b2074ece501826e 100644
--- a/containers/dockerfile-template
+++ b/containers/dockerfile-template
@@ -44,10 +44,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${{conda_env}} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
@@ -60,5 +56,5 @@ RUN mkdir roq_basis \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_linear.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_quadratic.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/params.dat \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis.hdf5 \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband.hdf5
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_addcal.hdf5 \
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband_addcal.hdf5
diff --git a/containers/v3-dockerfile-test-suite-python310 b/containers/v3-dockerfile-test-suite-python310
index f4c888d067de9a8dfd5e1d0581f88df5dc3ee229..96d8324bdd2d1e3ae4b38bae7d4d1f39be66b371 100644
--- a/containers/v3-dockerfile-test-suite-python310
+++ b/containers/v3-dockerfile-test-suite-python310
@@ -46,10 +46,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
@@ -62,5 +58,5 @@ RUN mkdir roq_basis \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_linear.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_quadratic.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/params.dat \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis.hdf5 \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband.hdf5
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_addcal.hdf5 \
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband_addcal.hdf5
diff --git a/containers/v3-dockerfile-test-suite-python38 b/containers/v3-dockerfile-test-suite-python38
index a45a4ab5be24de1271afdaa8e52cfc48438bfe4a..4c3df21441d80b8f8ffbbc2c581235df577ba619 100644
--- a/containers/v3-dockerfile-test-suite-python38
+++ b/containers/v3-dockerfile-test-suite-python38
@@ -46,10 +46,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
@@ -62,5 +58,5 @@ RUN mkdir roq_basis \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_linear.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_quadratic.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/params.dat \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis.hdf5 \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband.hdf5
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_addcal.hdf5 \
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband_addcal.hdf5
diff --git a/containers/v3-dockerfile-test-suite-python39 b/containers/v3-dockerfile-test-suite-python39
index a15c0e188063889495b1c7ccf3c08a1bc19860af..af4da1c1a8ce2009c7ccb5c22ad2f7950b903def 100644
--- a/containers/v3-dockerfile-test-suite-python39
+++ b/containers/v3-dockerfile-test-suite-python39
@@ -46,10 +46,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
@@ -62,5 +58,5 @@ RUN mkdir roq_basis \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_linear.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/fnodes_quadratic.npy \
     && wget https://git.ligo.org/lscsoft/ROQ_data/raw/master/IMRPhenomPv2/4s/params.dat \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis.hdf5 \
-    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband.hdf5
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_addcal.hdf5 \
+    && wget https://git.ligo.org/soichiro.morisaki/roq_basis/raw/main/IMRPhenomD/16s_nospins/basis_multiband_addcal.hdf5
diff --git a/docs/likelihood.txt b/docs/likelihood.txt
index 46c38102bc83d2e06d8476eb4b5a5502fa181898..79f917ff4f0fd8f17ec13802f40332d979087c65 100644
--- a/docs/likelihood.txt
+++ b/docs/likelihood.txt
@@ -128,9 +128,7 @@ In :code:`bilby`, we can code this up as a likelihood in the following way::
            self.function = function
 
-           # These lines of code infer the parameters from the provided function
-           parameters = inspect.getargspec(function).args
-           parameters.pop(0)
-           super().__init__(parameters=dict.fromkeys(parameters))
+           # The parameters dict starts empty; the sampler fills it at run time
+           super().__init__(parameters=dict())
 
 
        def log_likelihood(self):
@@ -197,10 +195,10 @@ instantiating the likelihood::
           self.function = function
 
           # These lines of code infer the parameters from the provided function
-          parameters = inspect.getargspec(function).args
-          parameters.pop(0)
+          parameters = inspect.getfullargspec(function).args
+          del parameters[0]
           super().__init__(parameters=dict.fromkeys(parameters))
-                    self.parameters = dict.fromkeys(parameters)
+          self.parameters = dict.fromkeys(parameters)
 
           self.function_keys = self.parameters.keys()
           if self.sigma is None:
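
The corrected snippet uses :code:`inspect.getfullargspec`
(:code:`inspect.getargspec` was removed in Python 3.11); a quick illustration
of the inference step::

    import inspect

    def model(x, m, c):
        return m * x + c

    parameters = inspect.getfullargspec(model).args
    del parameters[0]  # drop the independent variable, x
    print(dict.fromkeys(parameters))  # {'m': None, 'c': None}
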
diff --git a/docs/samplers.txt b/docs/samplers.txt
index fad1d62ba2953e46bce4e635c98269c0e8aeecb1..69699dad0e3b22725b4d3f3f6d942763406df0da 100644
--- a/docs/samplers.txt
+++ b/docs/samplers.txt
@@ -69,7 +69,7 @@ MCMC samplers
 - bilby-mcmc :code:`bilby.bilby_mcmc.sampler.Bilby_MCMC`
 - emcee :code:`bilby.core.sampler.emcee.Emcee`
 - ptemcee :code:`bilby.core.sampler.ptemcee.Ptemcee`
-- pymc3 :code:`bilby.core.sampler.pymc3.Pymc3`
+- pymc :code:`bilby.core.sampler.pymc.Pymc`
 - zeus :code:`bilby.core.sampler.zeus.Zeus`
 
 
diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc3.py b/examples/core_examples/alternative_samplers/linear_regression_pymc.py
similarity index 97%
rename from examples/core_examples/alternative_samplers/linear_regression_pymc3.py
rename to examples/core_examples/alternative_samplers/linear_regression_pymc.py
index 75cbf16aed72a7c8e572df342a5817dc092499c7..0efc872beadb35acde686846681c10890a4b2dae 100644
--- a/examples/core_examples/alternative_samplers/linear_regression_pymc3.py
+++ b/examples/core_examples/alternative_samplers/linear_regression_pymc.py
@@ -11,7 +11,7 @@ import numpy as np
 from bilby.core.likelihood import GaussianLikelihood
 
 # A few simple setup steps
-label = "linear_regression_pymc3"
+label = "linear_regression_pymc"
 outdir = "outdir"
 bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
 
@@ -58,7 +58,7 @@ priors["c"] = bilby.core.prior.Uniform(-2, 2, "c")
 result = bilby.run_sampler(
     likelihood=likelihood,
     priors=priors,
-    sampler="pymc3",
+    sampler="pymc",
     injection_parameters=injection_parameters,
     outdir=outdir,
     draws=2000,
diff --git a/examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
similarity index 77%
rename from examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py
rename to examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
index d2074304f83064167facf1166e7ec582f8056fb7..e9763770c8ad96b32adbf256cc44aabf0ceffb1e 100644
--- a/examples/core_examples/alternative_samplers/linear_regression_pymc3_custom_likelihood.py
+++ b/examples/core_examples/alternative_samplers/linear_regression_pymc_custom_likelihood.py
@@ -11,10 +11,10 @@ would give equivalent results as using the pre-defined 'Gaussian Likelihood'
 import bilby
 import matplotlib.pyplot as plt
 import numpy as np
-import pymc3 as pm
+import pymc as pm
 
 # A few simple setup steps
-label = "linear_regression_pymc3_custom_likelihood"
+label = "linear_regression_pymc_custom_likelihood"
 outdir = "outdir"
 bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
 
@@ -50,7 +50,7 @@ fig.savefig("{}/{}_data.png".format(outdir, label))
 
 # Parameter estimation: we now define a Gaussian Likelihood class relevant for
 # our model.
-class GaussianLikelihoodPyMC3(bilby.core.likelihood.GaussianLikelihood):
+class GaussianLikelihoodPyMC(bilby.core.likelihood.GaussianLikelihood):
     def __init__(self, x, y, sigma, func):
         """
         A general Gaussian likelihood - the parameters are inferred from the
@@ -68,45 +68,44 @@ class GaussianLikelihoodPyMC3(bilby.core.likelihood.GaussianLikelihood):
             will require a prior and will be sampled over (unless a fixed
             value is given).
         """
-        super(GaussianLikelihoodPyMC3, self).__init__(x=x, y=y, func=func, sigma=sigma)
+        super(GaussianLikelihoodPyMC, self).__init__(x=x, y=y, func=func, sigma=sigma)
 
     def log_likelihood(self, sampler=None):
         """
         Parameters
         ----------
-        sampler: :class:`bilby.core.sampler.Pymc3`
+        sampler: :class:`bilby.core.sampler.Pymc`
             A Sampler object must be passed containing the prior distributions
-            and PyMC3 :class:`~pymc3.Model` to use as a context manager.
+            and PyMC :class:`~pymc.Model` to use as a context manager.
             If this is not passed, the super class is called and the regular
             likelihood is evaluated.
         """
 
-        from bilby.core.sampler import Pymc3
+        from bilby.core.sampler import Pymc
 
-        if not isinstance(sampler, Pymc3):
-            print(sampler, type(sampler))
-            return super(GaussianLikelihoodPyMC3, self).log_likelihood()
+        if not isinstance(sampler, Pymc):
+            return super(GaussianLikelihoodPyMC, self).log_likelihood()
 
-        if not hasattr(sampler, "pymc3_model"):
-            raise AttributeError("Sampler has not PyMC3 model attribute")
+        if not hasattr(sampler, "pymc_model"):
+            raise AttributeError("Sampler has no PyMC model attribute")
 
-        with sampler.pymc3_model:
-            mdist = sampler.pymc3_priors["m"]
-            cdist = sampler.pymc3_priors["c"]
+        with sampler.pymc_model:
+            mdist = sampler.pymc_priors["m"]
+            cdist = sampler.pymc_priors["c"]
 
             mu = model(time, mdist, cdist)
 
             # set the likelihood distribution
-            pm.Normal("likelihood", mu=mu, sd=self.sigma, observed=self.y)
+            pm.Normal("likelihood", mu=mu, sigma=self.sigma, observed=self.y)
 
 
 # Now lets instantiate a version of our GaussianLikelihood, giving it
 # the time, data and signal model
-likelihood = GaussianLikelihoodPyMC3(time, data, sigma, model)
+likelihood = GaussianLikelihoodPyMC(time, data, sigma, model)
 
 
-# Define a custom prior for one of the parameter for use with PyMC3
-class PyMC3UniformPrior(bilby.core.prior.Uniform):
+# Define a custom prior for one of the parameters for use with PyMC
+class PyMCUniformPrior(bilby.core.prior.Uniform):
     def __init__(self, minimum, maximum, name=None, latex_label=None):
         """
         Uniform prior with bounds (should be equivalent to bilby.prior.Uniform)
@@ -124,10 +123,10 @@ class PyMC3UniformPrior(bilby.core.prior.Uniform):
         float or array to be passed to the superclass.
         """
 
-        from bilby.core.sampler import Pymc3
+        from bilby.core.sampler import Pymc
 
-        if not isinstance(sampler, Pymc3):
-            return super(PyMC3UniformPrior, self).ln_prob(sampler)
+        if not isinstance(sampler, Pymc):
+            return super(PyMCUniformPrior, self).ln_prob(sampler)
 
         return pm.Uniform(self.name, lower=self.minimum, upper=self.maximum)
 
@@ -136,13 +135,13 @@ class PyMC3UniformPrior(bilby.core.prior.Uniform):
 # We make a prior
 priors = dict()
 priors["m"] = bilby.core.prior.Uniform(0, 5, "m")
-priors["c"] = PyMC3UniformPrior(-2, 2, "c")
+priors["c"] = PyMCUniformPrior(-2, 2, "c")
 
 # And run sampler
 result = bilby.run_sampler(
     likelihood=likelihood,
     priors=priors,
-    sampler="pymc3",
+    sampler="pymc",
     draws=1000,
     tune=1000,
     discard_tuned_samples=True,
diff --git a/examples/core_examples/gaussian_process_celerite_example.py b/examples/core_examples/gaussian_process_celerite_example.py
index 417493f57745ba211b333d70094d731bed0291cf..574566a9af00c4208667f62b11572ff3892afc76 100644
--- a/examples/core_examples/gaussian_process_celerite_example.py
+++ b/examples/core_examples/gaussian_process_celerite_example.py
@@ -1,13 +1,11 @@
-import matplotlib.pyplot as plt
-import numpy as np
 from pathlib import Path
 
-import celerite.terms
-
 import bilby
+import celerite.terms
+import matplotlib.pyplot as plt
+import numpy as np
 from bilby.core.prior import Uniform
 
-
 # In this example we show how we can use the `celerite` package within `bilby`.
 # We begin by synthesizing some data and then use a simple Gaussian Process
 # model to fit and interpolate the data.
diff --git a/examples/core_examples/gaussian_process_george_example.py b/examples/core_examples/gaussian_process_george_example.py
index 2c2946583cbe7ab32aa4ddd106e1b2f44559e76e..969a56a72f5182962a8f7087041b57e791c48248 100644
--- a/examples/core_examples/gaussian_process_george_example.py
+++ b/examples/core_examples/gaussian_process_george_example.py
@@ -1,13 +1,11 @@
-import matplotlib.pyplot as plt
-import numpy as np
 from pathlib import Path
 
-import george
-
 import bilby
+import george
+import matplotlib.pyplot as plt
+import numpy as np
 from bilby.core.prior import Uniform
 
-
 # In this example we show how we can use the `george` package within
 # `bilby`. We begin by synthesizing some data and then use a simple Gaussian
 # Process model to fit and interpolate the data. `bilby` implements a
diff --git a/examples/core_examples/hyper_parameter_example.py b/examples/core_examples/hyper_parameter_example.py
index 638d47edabaa280881bfc04703f5504402cbb5f8..288e9c21b07aa254da6725f57d95b7e3eb05fa99 100644
--- a/examples/core_examples/hyper_parameter_example.py
+++ b/examples/core_examples/hyper_parameter_example.py
@@ -4,7 +4,6 @@ An example of how to use bilby to perform parameter estimation for hyper params
 """
 import matplotlib.pyplot as plt
 import numpy as np
-
 from bilby.core.likelihood import GaussianLikelihood
 from bilby.core.prior import Uniform
 from bilby.core.result import make_pp_plot
diff --git a/examples/core_examples/slabspike_example.py b/examples/core_examples/slabspike_example.py
index 21df0790ff55d88a183686ec76c0f5c96b08d3d2..2c42c174bda64e8535017cfce190850e3e9faad5 100644
--- a/examples/core_examples/slabspike_example.py
+++ b/examples/core_examples/slabspike_example.py
@@ -13,11 +13,10 @@ To install `PyMultiNest` call
 $ conda install -c conda-forge pymultinest
 """
 
+import bilby
 import matplotlib.pyplot as plt
 import numpy as np
 
-import bilby
-
 outdir = "outdir"
 label = "slabspike"
 bilby.utils.check_directory_exists_and_if_not_mkdir(outdir)
diff --git a/examples/gw_examples/injection_examples/australian_detector.py b/examples/gw_examples/injection_examples/australian_detector.py
index 6b903a9fc0695a279ef777e0d0fa166aa7fae140..85c5cf03e7cb640207e27e731b95c1b3acae7820 100644
--- a/examples/gw_examples/injection_examples/australian_detector.py
+++ b/examples/gw_examples/injection_examples/australian_detector.py
@@ -3,52 +3,54 @@
 Tutorial to demonstrate a new interferometer
 
 We place a new instrument in Gingin, with an A+ sensitivity in a network of A+
-interferometers at Hanford and Livingston
-"""
+interferometers at Hanford and Livingston.
 
-import numpy as np
+This requires :code:`gwinc` to be installed; it is available via conda-forge.
+"""
 
 import bilby
 import gwinc
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment that we're going
 # to inject the signal into
-duration = 4.
-sampling_frequency = 2048.
+duration = 4
+sampling_frequency = 1024
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'australian_detector'
+outdir = "outdir"
+label = "australian_detector"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
 np.random.seed(88170232)
 
 # create a new detector using a PyGwinc sensitivity curve
-frequencies = np.logspace(0, 3, 1000)
-budget, gwinc_ifo, _, _ = gwinc.load_ifo('Aplus')
-gwinc_ifo = gwinc.precompIFO(frequencies, gwinc_ifo)
-gwinc_traces = budget(frequencies, ifo=gwinc_ifo).calc_trace()
-gwinc_noises = {n: d[0] for n, d in gwinc_traces.items()}
-
-Aplus_psd = gwinc_noises['Total']
+curve = gwinc.load_budget("Aplus").run()
 
 # Set up the detector as a four-kilometer detector in Gingin
 # The location of this detector is not defined in Bilby, so we need to add it
 AusIFO = bilby.gw.detector.Interferometer(
     power_spectral_density=bilby.gw.detector.PowerSpectralDensity(
-        frequency_array=frequencies, psd_array=Aplus_psd),
-    name='AusIFO', length=4,
-    minimum_frequency=min(frequencies), maximum_frequency=max(frequencies),
-    latitude=-31.34, longitude=115.91,
-    elevation=0., xarm_azimuth=2., yarm_azimuth=125.)
+        frequency_array=curve.freq, psd_array=curve.psd
+    ),
+    name="AusIFO",
+    length=4,
+    minimum_frequency=20,
+    maximum_frequency=sampling_frequency / 2,
+    latitude=-31.34,
+    longitude=115.91,
+    elevation=0.0,
+    xarm_azimuth=2.0,
+    yarm_azimuth=125.0,
+)
 
 # Set up two other detectors at Hanford and Livingston
-interferometers = bilby.gw.detector.InterferometerList(['H1', 'L1'])
-for interferometer in interferometers:
-    interferometer.power_spectral_density =\
-        bilby.gw.detector.PowerSpectralDensity(
-            frequency_array=frequencies, psd_array=Aplus_psd)
+interferometers = bilby.gw.detector.InterferometerList(["H1", "L1"])
+for ifo in interferometers:
+    ifo.power_spectral_density = bilby.gw.detector.PowerSpectralDensity(
+        frequency_array=curve.freq, psd_array=curve.psd
+    )
 
 # append the Australian detector to the list of other detectors
 interferometers.append(AusIFO)
@@ -58,53 +60,88 @@ interferometers.append(AusIFO)
 # as we're using a three-detector network of A+, we inject a GW150914-like
 # signal at 4 Gpc
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=4000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=0.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=4000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=0.2108,
+)
 
 
 # Fixed arguments passed into the source model
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50.0)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
-start_time = injection_parameters['geocent_time'] + 2 - duration
+start_time = injection_parameters["geocent_time"] + 2 - duration
 
 # inject the signal into the interferometers
 
-for interferometer in interferometers:
-    interferometer.set_strain_data_from_power_spectral_density(
-        sampling_frequency=sampling_frequency, duration=duration)
-    interferometer.inject_signal(
-        parameters=injection_parameters, waveform_generator=waveform_generator)
+for ifo in interferometers:
+    ifo.set_strain_data_from_power_spectral_density(
+        sampling_frequency=sampling_frequency, duration=duration
+    )
+    ifo.inject_signal(
+        parameters=injection_parameters, waveform_generator=waveform_generator
+    )
 
     # plot the data for sanity
-    signal = interferometer.get_detector_response(
-        waveform_generator.frequency_domain_strain(), injection_parameters)
-    interferometer.plot_data(signal=signal, outdir=outdir, label=label)
+    signal = ifo.get_detector_response(
+        waveform_generator.frequency_domain_strain(), injection_parameters
+    )
+    ifo.plot_data(signal=signal, outdir=outdir, label=label)
 
 # set up priors
 priors = bilby.gw.prior.BBHPriorDict()
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi',
-            'geocent_time', 'phase']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "psi",
+    "geocent_time",
+    "phase",
+    "ra",
+    "dec",
+    "luminosity_distance",
+    "theta_jn",
+]:
     priors[key] = injection_parameters[key]
 
 # Initialise the likelihood by passing in the interferometer data (IFOs)
 # and the waveform generator
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=interferometers, waveform_generator=waveform_generator,
-    time_marginalization=False, phase_marginalization=False,
-    distance_marginalization=False, priors=priors)
+    interferometers=interferometers,
+    waveform_generator=waveform_generator,
+)
 
 
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    npoints=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # make some plots of the outputs
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/binary_neutron_star_example.py b/examples/gw_examples/injection_examples/binary_neutron_star_example.py
index e6aa18a9a47c9594f05a446d47d6294a31fc6032..90308c3c5b3d2bc973335eb1658647af685a6380 100644
--- a/examples/gw_examples/injection_examples/binary_neutron_star_example.py
+++ b/examples/gw_examples/injection_examples/binary_neutron_star_example.py
@@ -9,13 +9,12 @@ tidal deformabilities
 """
 
 
-import numpy as np
-
 import bilby
+import numpy as np
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'bns_example'
+outdir = "outdir"
+label = "bns_example"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -26,39 +25,56 @@ np.random.seed(88170235)
 # parameters, including masses of the two black holes (mass_1, mass_2),
 # aligned spins of both black holes (chi_1, chi_2), etc.
 injection_parameters = dict(
-    mass_1=1.5, mass_2=1.3, chi_1=0.02, chi_2=0.02, luminosity_distance=50.,
-    theta_jn=0.4, psi=2.659, phase=1.3, geocent_time=1126259642.413,
-    ra=1.375, dec=-1.2108, lambda_1=400, lambda_2=450)
+    mass_1=1.5,
+    mass_2=1.3,
+    chi_1=0.02,
+    chi_2=0.02,
+    luminosity_distance=50.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+    lambda_1=400,
+    lambda_2=450,
+)
 
 # Set the duration and sampling frequency of the data segment that we're going
 # to inject the signal into. For the
 # TaylorF2 waveform, we cut the signal close to the isco frequency
 duration = 32
-sampling_frequency = 2 * 1024
-start_time = injection_parameters['geocent_time'] + 2 - duration
+sampling_frequency = 2048
+start_time = injection_parameters["geocent_time"] + 2 - duration
 
 # Fixed arguments passed into the source model. The analysis starts at 40 Hz.
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2_NRTidal',
-                          reference_frequency=50., minimum_frequency=40.0)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomPv2_NRTidal",
+    reference_frequency=50.0,
+    minimum_frequency=40.0,
+)
 
 # Create the waveform_generator using a LAL Binary Neutron Star source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_neutron_star,
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_neutron_star_parameters,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.  In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).
 # These default to their design sensitivity and start at 40 Hz.
-interferometers = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+interferometers = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 for interferometer in interferometers:
     interferometer.minimum_frequency = 40
 interferometers.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=start_time)
-interferometers.inject_signal(parameters=injection_parameters,
-                              waveform_generator=waveform_generator)
+    sampling_frequency=sampling_frequency, duration=duration, start_time=start_time
+)
+interferometers.inject_signal(
+    parameters=injection_parameters, waveform_generator=waveform_generator
+)
 
 # Load the default prior for binary neutron stars.
 # We're going to sample in chirp_mass, symmetric_mass_ratio, lambda_tilde, and
@@ -66,31 +82,45 @@ interferometers.inject_signal(parameters=injection_parameters,
 # BNS have aligned spins by default, if you want to allow precessing spins
 # pass aligned_spin=False to the BNSPriorDict
 priors = bilby.gw.prior.BNSPriorDict()
-for key in ['psi', 'geocent_time', 'ra', 'dec', 'chi_1', 'chi_2',
-            'theta_jn', 'luminosity_distance', 'phase']:
+for key in [
+    "psi",
+    "geocent_time",
+    "ra",
+    "dec",
+    "chi_1",
+    "chi_2",
+    "theta_jn",
+    "luminosity_distance",
+    "phase",
+]:
     priors[key] = injection_parameters[key]
-priors.pop('mass_ratio')
-priors.pop('lambda_1')
-priors.pop('lambda_2')
-priors['chirp_mass'] = bilby.core.prior.Gaussian(
-    1.215, 0.1, name='chirp_mass', unit='$M_{\\odot}$')
-priors['symmetric_mass_ratio'] = bilby.core.prior.Uniform(
-    0.1, 0.25, name='symmetric_mass_ratio')
-priors['lambda_tilde'] = bilby.core.prior.Uniform(0, 5000, name='lambda_tilde')
-priors['delta_lambda'] = bilby.core.prior.Uniform(
-    -5000, 5000, name='delta_lambda')
+del priors["mass_ratio"], priors["lambda_1"], priors["lambda_2"]
+priors["chirp_mass"] = bilby.core.prior.Gaussian(
+    1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$"
+)
+priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform(
+    0.1, 0.25, name="symmetric_mass_ratio"
+)
+priors["lambda_tilde"] = bilby.core.prior.Uniform(0, 5000, name="lambda_tilde")
+priors["delta_lambda"] = bilby.core.prior.Uniform(-5000, 5000, name="delta_lambda")
 
 # Initialise the likelihood by passing in the interferometer data (IFOs)
 # and the waveform generator
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=interferometers, waveform_generator=waveform_generator,
-    time_marginalization=False, phase_marginalization=False,
-    distance_marginalization=False, priors=priors)
+    interferometers=interferometers,
+    waveform_generator=waveform_generator,
+)
 
 # Run sampler.  In this case we're going to use the `nestle` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='nestle', npoints=100,
-    injection_parameters=injection_parameters, outdir=outdir, label=label,
-    conversion_function=bilby.gw.conversion.generate_all_bns_parameters)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="nestle",
+    npoints=100,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+    conversion_function=bilby.gw.conversion.generate_all_bns_parameters,
+)
 
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/bns_eos_example.py b/examples/gw_examples/injection_examples/bns_eos_example.py
index 5e32138b94bf6b6f402f7618fc07a0d12bcc3c61..881cc5ae8f09f02e2d2c2b705023a5fe5ca4207b 100644
--- a/examples/gw_examples/injection_examples/bns_eos_example.py
+++ b/examples/gw_examples/injection_examples/bns_eos_example.py
@@ -1,22 +1,20 @@
 #!/usr/bin/env python
 """
 Tutorial to demonstrate running parameter estimation on a binary neutron star
-system taking into account tidal deformabilities.
+system, modelling the tidal deformabilities with a physically motivated
+equation-of-state model.
 
-This example estimates the masses using a uniform prior in both component masses
-and also estimates the tidal deformabilities using a uniform prior in both
-tidal deformabilities
+WARNING: The code is extremely slow.
 """
 
 
-import numpy as np
-
 import bilby
-from bilby.gw.eos import TabularEOS, EOSFamily
+import numpy as np
+from bilby.gw.eos import EOSFamily, TabularEOS
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'bns_eos_example'
+outdir = "outdir"
+label = "bns_eos_example"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -32,7 +30,7 @@ np.random.seed(88170235)
 # assuming a specific equation of state, and calculating
 # corresponding tidal deformability parameters from the EoS and
 # masses.
-mpa1_eos = TabularEOS('MPA1')
+mpa1_eos = TabularEOS("MPA1")
 mpa1_fam = EOSFamily(mpa1_eos)
 
 mass_1 = 1.5
@@ -42,41 +40,58 @@ lambda_2 = mpa1_fam.lambda_from_mass(mass_2)
 
 
 injection_parameters = dict(
-    mass_1=mass_1, mass_2=mass_2, chi_1=0.02, chi_2=0.02, luminosity_distance=50.,
-    theta_jn=0.4, psi=2.659, phase=1.3, geocent_time=1126259642.413,
-    ra=1.375, dec=-1.2108, lambda_1=lambda_1, lambda_2=lambda_2)
+    mass_1=mass_1,
+    mass_2=mass_2,
+    chi_1=0.02,
+    chi_2=0.02,
+    luminosity_distance=50.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+    lambda_1=lambda_1,
+    lambda_2=lambda_2,
+)
 
 # Set the duration and sampling frequency of the data segment that we're going
 # to inject the signal into. For the
 # TaylorF2 waveform, we cut the signal close to the isco frequency
 duration = 32
-sampling_frequency = 2 * 1024
-start_time = injection_parameters['geocent_time'] + 2 - duration
+sampling_frequency = 2048
+start_time = injection_parameters["geocent_time"] + 2 - duration
 
 # Fixed arguments passed into the source model. The analysis starts at 40 Hz.
 # Note that the EoS sampling is agnostic to the waveform model as long as the
 # approximant can include tides.
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2_NRTidal',
-                          reference_frequency=50., minimum_frequency=40.0)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomPv2_NRTidal",
+    reference_frequency=50.0,
+    minimum_frequency=40.0,
+)
 
 # Create the waveform_generator using a LAL Binary Neutron Star source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_neutron_star,
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_neutron_star_parameters,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.  In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).
 # These default to their design sensitivity and start at 40 Hz.
-interferometers = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+interferometers = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 for interferometer in interferometers:
     interferometer.minimum_frequency = 40
 interferometers.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=start_time)
-interferometers.inject_signal(parameters=injection_parameters,
-                              waveform_generator=waveform_generator)
+    sampling_frequency=sampling_frequency, duration=duration, start_time=start_time
+)
+interferometers.inject_signal(
+    parameters=injection_parameters, waveform_generator=waveform_generator
+)
 
 # We're going to sample in chirp_mass, symmetric_mass_ratio, and
 # specific EoS model parameters. We're using a 4-parameter
@@ -84,38 +99,62 @@ interferometers.inject_signal(parameters=injection_parameters,
 # BNS have aligned spins by default, if you want to allow precessing spins
 # pass aligned_spin=False to the BNSPriorDict
 priors = bilby.gw.prior.BNSPriorDict()
-for key in ['psi', 'geocent_time', 'ra', 'dec', 'chi_1', 'chi_2',
-            'theta_jn', 'luminosity_distance', 'phase']:
+for key in [
+    "psi",
+    "geocent_time",
+    "ra",
+    "dec",
+    "chi_1",
+    "chi_2",
+    "theta_jn",
+    "luminosity_distance",
+    "phase",
+]:
     priors[key] = injection_parameters[key]
-priors.pop('mass_1')
-priors.pop('mass_2')
-priors.pop('lambda_1')
-priors.pop('lambda_2')
-priors.pop('mass_ratio')
-priors['chirp_mass'] = bilby.core.prior.Gaussian(1.215, 0.1, name='chirp_mass', unit='$M_{\\odot}$')
-priors['symmetric_mass_ratio'] = bilby.core.prior.Uniform(0.1, 0.25, name='symmetric_mass_ratio')
-priors['eos_spectral_gamma_0'] = bilby.core.prior.Uniform(0.2, 2.0, name='gamma0', latex_label='$\\gamma_0')
-priors['eos_spectral_gamma_1'] = bilby.core.prior.Uniform(-1.6, 1.7, name='gamma1', latex_label='$\\gamma_1')
-priors['eos_spectral_gamma_2'] = bilby.core.prior.Uniform(-0.6, 0.6, name='gamma2', latex_label='$\\gamma_2')
-priors['eos_spectral_gamma_3'] = bilby.core.prior.Uniform(-0.02, 0.02, name='gamma3', latex_label='$\\gamma_3')
+for key in ["mass_1", "mass_2", "lambda_1", "lambda_2", "mass_ratio"]:
+    del priors[key]
+priors["chirp_mass"] = bilby.core.prior.Gaussian(
+    1.215, 0.1, name="chirp_mass", unit="$M_{\\odot}$"
+)
+priors["symmetric_mass_ratio"] = bilby.core.prior.Uniform(
+    0.1, 0.25, name="symmetric_mass_ratio"
+)
+priors["eos_spectral_gamma_0"] = bilby.core.prior.Uniform(
+    0.2, 2.0, name="gamma0", latex_label="$\\gamma_0"
+)
+priors["eos_spectral_gamma_1"] = bilby.core.prior.Uniform(
+    -1.6, 1.7, name="gamma1", latex_label="$\\gamma_1"
+)
+priors["eos_spectral_gamma_2"] = bilby.core.prior.Uniform(
+    -0.6, 0.6, name="gamma2", latex_label="$\\gamma_2"
+)
+priors["eos_spectral_gamma_3"] = bilby.core.prior.Uniform(
+    -0.02, 0.02, name="gamma3", latex_label="$\\gamma_3"
+)
 
 # The eos_check prior imposes several hard physical constraints on samples, such
 # as enforcing causality and monotonicity of the EoSs. In almost every conceivable
 # sampling scenario, this should be enabled.
-priors['eos_check'] = bilby.gw.prior.EOSCheck()
+priors["eos_check"] = bilby.gw.prior.EOSCheck()
 
 # Initialise the likelihood by passing in the interferometer data (IFOs)
 # and the waveform generator
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=interferometers, waveform_generator=waveform_generator,
-    time_marginalization=False, phase_marginalization=False,
-    distance_marginalization=False, priors=priors)
+    interferometers=interferometers,
+    waveform_generator=waveform_generator,
+)
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label,
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    npoints=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
     conversion_function=bilby.gw.conversion.generate_all_bns_parameters,
-    resume=True)
+    resume=True,
+)
 
 result.plot_corner()
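
A note on the parameter-fixing pattern used above: assigning a bare float into a
bilby PriorDict (as done for psi, geocent_time, etc.) pins that parameter, since
floats are promoted to delta-function priors, as noted explicitly in
calibration_marginalization_example.py below. A minimal sketch with hypothetical
values:

    import bilby

    priors = bilby.core.prior.PriorDict(
        dict(
            chirp_mass=bilby.core.prior.Gaussian(1.215, 0.1, name="chirp_mass"),
            phase=1.3,  # a bare float becomes a DeltaFunction prior
        )
    )
    print(priors["phase"])  # a DeltaFunction peaked at 1.3
    print(priors.sample(3))  # only chirp_mass varies between draws
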
diff --git a/examples/gw_examples/injection_examples/calibration_example.py b/examples/gw_examples/injection_examples/calibration_example.py
index 8b9556f48ca43a2419629cc1d4f8fdb1daa25b96..cd5c4da87b9ef79acc53af50167a67c1aa343c19 100644
--- a/examples/gw_examples/injection_examples/calibration_example.py
+++ b/examples/gw_examples/injection_examples/calibration_example.py
@@ -2,20 +2,23 @@
 """
 Tutorial to demonstrate running parameter estimation with calibration
 uncertainties included.
+
+We set up the full problem as usual, but sample over only a small number of
+calibration parameters.
 """
 
-import numpy as np
 import bilby
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment
 # that we're going to create and inject the signal into.
 
-duration = 4.
-sampling_frequency = 2048.
+duration = 4
+sampling_frequency = 1024
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'calibration'
+outdir = "outdir"
+label = "calibration"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -26,37 +29,58 @@ np.random.seed(88170235)
 # parameters, including masses of the two black holes (mass_1, mass_2),
 # spins of both black holes (a, tilt, phi), etc.
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
 # Fixed arguments passed into the source model
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50.0)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    parameters=injection_parameters, waveform_arguments=waveform_arguments)
+    parameters=injection_parameters,
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers. In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).
 # These default to their design sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 for ifo in ifos:
-    injection_parameters.update({
-        'recalib_{}_amplitude_{}'.format(ifo.name, ii): 0.1 for ii in range(5)})
-    injection_parameters.update({
-        'recalib_{}_phase_{}'.format(ifo.name, ii): 0.01 for ii in range(5)})
+    injection_parameters.update(
+        {f"recalib_{ifo.name}_amplitude_{ii}": 0.1 for ii in range(5)}
+    )
+    injection_parameters.update(
+        {f"recalib_{ifo.name}_phase_{ii}": 0.01 for ii in range(5)}
+    )
     ifo.calibration_model = bilby.gw.calibration.CubicSpline(
-        prefix='recalib_{}_'.format(ifo.name),
+        prefix=f"recalib_{ifo.name}_",
         minimum_frequency=ifo.minimum_frequency,
-        maximum_frequency=ifo.maximum_frequency, n_points=5)
+        maximum_frequency=ifo.maximum_frequency,
+        n_points=5,
+    )
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration)
-ifos.inject_signal(parameters=injection_parameters,
-                   waveform_generator=waveform_generator)
+    sampling_frequency=sampling_frequency, duration=duration
+)
+ifos.inject_signal(
+    parameters=injection_parameters, waveform_generator=waveform_generator
+)
 
 # Set up prior, which is a dictionary
 # Here we fix the injected cbc parameters and most of the calibration parameters
@@ -64,21 +88,30 @@ ifos.inject_signal(parameters=injection_parameters,
 # We allow a subset of the calibration parameters to vary.
 priors = injection_parameters.copy()
 for key in injection_parameters:
-    if 'recalib' in key:
+    if "recalib" in key:
         priors[key] = injection_parameters[key]
-for name in ['recalib_H1_amplitude_0', 'recalib_H1_amplitude_1']:
-    priors[name] = bilby.prior.Gaussian(
-        mu=0, sigma=0.2, name=name, latex_label='H1 $A_{}$'.format(name[-1]))
+for name in ["recalib_H1_amplitude_0", "recalib_H1_amplitude_1"]:
+    priors[name] = bilby.core.prior.Gaussian(
+        mu=0, sigma=0.2, name=name, latex_label=f"H1 $A_{name[-1]}$"
+    )
 
 # Initialise the likelihood by passing in the interferometer data (IFOs) and
 # the waveform generator
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    npoints=1000,
+    sample="unif",
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # make some plots of the outputs
 result.plot_corner()
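
For orientation, the recalib_{ifo}_amplitude_{0..4} and recalib_{ifo}_phase_{0..4}
parameters above are the amplitude and phase corrections at the spline nodes. A
rough sketch of where those nodes sit, assuming (as for bilby's CubicSpline
calibration model) that they are spaced uniformly in log frequency:

    import numpy as np

    minimum_frequency, maximum_frequency, n_points = 20.0, 1024.0, 5
    log_nodes = np.linspace(
        np.log10(minimum_frequency), np.log10(maximum_frequency), n_points
    )
    print(10 ** log_nodes)  # node frequencies controlled by recalib_*_{0..4}
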
diff --git a/examples/gw_examples/injection_examples/calibration_marginalization_example.py b/examples/gw_examples/injection_examples/calibration_marginalization_example.py
index 73279aadbf4435d59d51ce08b33d59569e475028..2f8dc326a04299df5dba5168097f67dcf04723fb 100644
--- a/examples/gw_examples/injection_examples/calibration_marginalization_example.py
+++ b/examples/gw_examples/injection_examples/calibration_marginalization_example.py
@@ -4,19 +4,20 @@ Tutorial to demonstrate running parameter estimation with calibration
 uncertainties marginalized over using a finite set of realizations.
 """
 
-import numpy as np
+from copy import deepcopy
+
 import bilby
-from copy import copy
 import matplotlib.pyplot as plt
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment
 # that we're going to create and inject the signal into.
-duration = 4.
-sampling_frequency = 2048.
+duration = 4
+sampling_frequency = 1024
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'calibration_marginalization'
+outdir = "outdir"
+label = "calibration_marginalization"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -27,110 +28,169 @@ np.random.seed(170817)
 # parameters, including masses of the two black holes (mass_1, mass_2),
 # spins of both black holes (a, tilt, phi), etc.
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
-start_time = injection_parameters['geocent_time'] - duration + 2
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
+start_time = injection_parameters["geocent_time"] - duration + 2
 
 # Fixed arguments passed into the source model
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50.0)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    parameters=injection_parameters, waveform_arguments=waveform_arguments)
-waveform_generator_rew = copy(waveform_generator)
+    parameters=injection_parameters,
+    waveform_arguments=waveform_arguments,
+)
+waveform_generator_rew = deepcopy(waveform_generator)
 
 # Set up interferometers. In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).
 # These default to their design sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 for ifo in ifos:
-    injection_parameters.update({
-        'recalib_{}_amplitude_{}'.format(ifo.name, ii): 0.0 for ii in range(10)})
-    injection_parameters.update({
-        'recalib_{}_phase_{}'.format(ifo.name, ii): 0.0 for ii in range(10)})
+    injection_parameters.update(
+        {f"recalib_{ifo.name}_amplitude_{ii}": 0.0 for ii in range(10)}
+    )
+    injection_parameters.update(
+        {f"recalib_{ifo.name}_phase_{ii}": 0.0 for ii in range(10)}
+    )
     ifo.calibration_model = bilby.gw.calibration.CubicSpline(
-        prefix='recalib_{}_'.format(ifo.name),
+        prefix=f"recalib_{ifo.name}_",
         minimum_frequency=ifo.minimum_frequency,
-        maximum_frequency=ifo.maximum_frequency, n_points=10)
+        maximum_frequency=ifo.maximum_frequency,
+        n_points=10,
+    )
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration, start_time=start_time)
-ifos.inject_signal(parameters=injection_parameters,
-                   waveform_generator=waveform_generator)
-ifos_rew = copy(ifos)
+    sampling_frequency=sampling_frequency, duration=duration, start_time=start_time
+)
+ifos.inject_signal(
+    parameters=injection_parameters, waveform_generator=waveform_generator
+)
+ifos_rew = deepcopy(ifos)
 
 # Set up prior, which is a dictionary
 # Here we fix the injected cbc parameters (except the distance)
 # to the injected values.
 priors = injection_parameters.copy()
-priors['luminosity_distance'] = bilby.prior.Uniform(
-    injection_parameters['luminosity_distance'] - 1000, injection_parameters['luminosity_distance'] + 1000,
-    name='luminosity_distance', latex_label='$d_L$')
+priors["luminosity_distance"] = bilby.prior.Uniform(
+    injection_parameters["luminosity_distance"] - 1000,
+    injection_parameters["luminosity_distance"] + 1000,
+    name="luminosity_distance",
+    latex_label="$d_L$",
+)
 for key in injection_parameters:
-    if 'recalib' in key:
+    if "recalib" in key:
         priors[key] = injection_parameters[key]
 
 # Convert to prior dictionary to replace the floats with delta function priors
 priors = bilby.core.prior.PriorDict(priors)
-priors_rew = copy(priors)
+priors_rew = deepcopy(priors)
 
 # Initialise the likelihood by passing in the interferometer data (IFOs) and
 # the waveform generator. Here we assume there is no calibration uncertainty
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator, priors=priors)
+    interferometers=ifos, waveform_generator=waveform_generator, priors=priors
+)
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    npoints=500,
+    walks=20,
+    nact=3,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # Set the log likelihood to be the full log likelihood rather than the log
 # likelihood ratio. This is used for the reweighting below.
-result.posterior['log_likelihood'] = result.posterior['log_likelihood'] + result.log_noise_evidence
+result.posterior["log_likelihood"] = (
+    result.posterior["log_likelihood"] + result.log_noise_evidence
+)
 
 # Setting the priors we want on the calibration response curve parameters - as an example.
-for name in ['recalib_H1_amplitude_1', 'recalib_H1_amplitude_4']:
+for name in ["recalib_H1_amplitude_1", "recalib_H1_amplitude_4"]:
     priors_rew[name] = bilby.prior.Gaussian(
-        mu=0, sigma=0.03, name=name, latex_label='H1 $A_{}$'.format(name[-1]))
+        mu=0, sigma=0.03, name=name, latex_label=f"H1 $A_{name[-1]}$"
+    )
 
 # Setting up the calibration marginalized likelihood.
 # We save the calibration response curve files into the output directory under {ifo.name}_calibration_file.h5
 cal_likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos_rew, waveform_generator=waveform_generator_rew,
-    calibration_marginalization=True, priors=priors_rew,
+    interferometers=ifos_rew,
+    waveform_generator=waveform_generator_rew,
+    calibration_marginalization=True,
+    priors=priors_rew,
     number_of_response_curves=100,
-    calibration_lookup_table={ifos[i].name: f'{outdir}/{ifos[i].name}_calibration_file.h5' for i in range(len(ifos))})
+    calibration_lookup_table={
+        ifo.name: f"{outdir}/{ifo.name}_calibration_file.h5"
+        for ifo in ifos
+    },
+)
 
 # Plot the magnitude of the curves to be used in the marginalization
-plt.semilogx(ifos[0].frequency_array[ifos[0].frequency_mask][0:-1],
-             cal_likelihood.calibration_draws[ifos[0].name][:, 0:-1].T)
+plt.semilogx(
+    ifos[0].frequency_array[ifos[0].frequency_mask][0:-1],
+    cal_likelihood.calibration_draws[ifos[0].name][:, 0:-1].T,
+)
 plt.xlim(20, 1024)
-plt.ylabel('Magnitude')
-plt.xlabel('Frequency [Hz]')
+plt.ylabel("Magnitude")
+plt.xlabel("Frequency [Hz]")
 plt.savefig(f"{outdir}/calibration_draws.pdf")
 plt.clf()
 
 # Reweight the posterior samples from a distribution with no calibration uncertainty to one with uncertainty.
 # This method uses rejection sampling, which can be inefficient at drawing samples at high SNR.
-result_rew = bilby.core.result.reweight(result,
-                                        new_likelihood=cal_likelihood,
-                                        conversion_function=bilby.gw.conversion.generate_all_bbh_parameters)
+result_rew = bilby.core.result.reweight(
+    result,
+    new_likelihood=cal_likelihood,
+    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters,
+)
 
 # Plot distance posterior with and without the calibration
-for res, label in zip([result, result_rew], ["No calibration uncertainty", "Calibration uncertainty"]):
-    plt.hist(res.posterior['luminosity_distance'], label=label, bins=50, histtype='step', density=True)
+for res, label in zip(
+    [result, result_rew], ["No calibration uncertainty", "Calibration uncertainty"]
+):
+    plt.hist(
+        res.posterior["luminosity_distance"],
+        label=label,
+        bins=50,
+        histtype="step",
+        density=True,
+    )
 plt.legend()
-plt.xlabel('Luminosity distance [Mpc]')
+plt.xlabel("Luminosity distance [Mpc]")
 plt.savefig(f"{outdir}/luminosity_distance_posterior.pdf")
 plt.clf()
 
 plt.hist(
-    result_rew.posterior['recalib_index'],
-    bins=np.linspace(0, cal_likelihood.number_of_response_curves - 1, cal_likelihood.number_of_response_curves),
-    density=True)
+    result_rew.posterior["recalib_index"],
+    bins=np.linspace(
+        0,
+        cal_likelihood.number_of_response_curves - 1,
+        cal_likelihood.number_of_response_curves,
+    ),
+    density=True,
+)
 plt.xlim(0, cal_likelihood.number_of_response_curves - 1)
-plt.xlabel('Calibration index')
+plt.xlabel("Calibration index")
 plt.savefig(f"{outdir}/calibration_index_histogram.pdf")
diff --git a/examples/gw_examples/injection_examples/change_sampled_parameters.py b/examples/gw_examples/injection_examples/change_sampled_parameters.py
index 4e0266d08b6290395d2a5e8dc47bdf0dc6ca9432..10eedb7a6d0c19c3e9fae60a214bd6093bf170ef 100644
--- a/examples/gw_examples/injection_examples/change_sampled_parameters.py
+++ b/examples/gw_examples/injection_examples/change_sampled_parameters.py
@@ -10,71 +10,107 @@ The cosmology is according to the Planck 2015 data release.
 import bilby
 import numpy as np
 
-
 bilby.core.utils.setup_logger(log_level="info")
 
-duration = 4.
-sampling_frequency = 2048.
-outdir = 'outdir'
+duration = 4
+sampling_frequency = 2048
+outdir = "outdir"
+label = "different_parameters"
 
 np.random.seed(151226)
 
 injection_parameters = dict(
-    total_mass=66., mass_ratio=0.9, a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000, theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    total_mass=66.0,
+    mass_ratio=0.9,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50.0)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 # We specify a function which transforms a dictionary of parameters into the
 # appropriate parameters for the source model.
 waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    sampling_frequency=sampling_frequency, duration=duration,
+    sampling_frequency=sampling_frequency,
+    duration=duration,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1', 'K1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1", "K1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # Set up prior
 # Note it is possible to sample in different parameters to those that were
 # injected.
 priors = bilby.gw.prior.BBHPriorDict()
-priors.pop('mass_1')
-priors.pop('mass_2')
-priors.pop('luminosity_distance')
-priors['chirp_mass'] = bilby.prior.Uniform(
-    name='chirp_mass', latex_label='$m_c$', minimum=13, maximum=45,
-    unit='$M_{\\odot}$')
-priors['symmetric_mass_ratio'] = bilby.prior.Uniform(
-    name='symmetric_mass_ratio', latex_label='q', minimum=0.1, maximum=0.25)
-priors['redshift'] = bilby.prior.Uniform(
-    name='redshift', latex_label='$z$', minimum=0, maximum=0.5)
+
+del priors["mass_ratio"]
+priors["symmetric_mass_ratio"] = bilby.prior.Uniform(
+    name="symmetric_mass_ratio", latex_label="q", minimum=0.1, maximum=0.25
+)
+
+del priors["luminosity_distance"]
+priors["redshift"] = bilby.prior.Uniform(
+    name="redshift", latex_label="$z$", minimum=0, maximum=0.5
+)
 # These parameters will not be sampled
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi',
-            'ra', 'dec', 'geocent_time', 'phase']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "psi",
+    "ra",
+    "dec",
+    "geocent_time",
+    "phase",
+]:
     priors[key] = injection_parameters[key]
-priors.pop('theta_jn')
-priors['cos_theta_jn'] = np.cos(injection_parameters['theta_jn'])
+del priors["theta_jn"]
+priors["cos_theta_jn"] = np.cos(injection_parameters["theta_jn"])
 print(priors)
 
 # Initialise GravitationalWaveTransient
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Run sampler
 # Note we've added a post-processing conversion function, this will generate
 # many useful additional parameters, e.g., source-frame masses.
 result = bilby.core.sampler.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', outdir=outdir,
-    injection_parameters=injection_parameters, label='DifferentParameters',
-    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    walks=25,
+    nact=5,
+    outdir=outdir,
+    injection_parameters=injection_parameters,
+    label=label,
+    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters,
+)
 result.plot_corner()
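
Sampling in (chirp_mass, symmetric_mass_ratio) works because these map one-to-one
onto the component masses, which the conversion function recovers after sampling.
The underlying algebra, as a self-contained sketch (inverting eta = q/(1+q)^2
with q <= 1, and Mc = M * eta^(3/5)):

    import numpy as np

    def component_masses(chirp_mass, eta):
        # take the q <= 1 root of eta * q**2 + (2 * eta - 1) * q + eta = 0
        q = (1 - 2 * eta - np.sqrt(1 - 4 * eta)) / (2 * eta)
        total_mass = chirp_mass * eta ** (-3 / 5)
        return total_mass / (1 + q), total_mass * q / (1 + q)

    # ~ (34.7, 31.3), matching the injected total_mass=66, mass_ratio=0.9
    print(component_masses(chirp_mass=28.68, eta=0.2493))
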
diff --git a/examples/gw_examples/injection_examples/create_your_own_source_model.py b/examples/gw_examples/injection_examples/create_your_own_source_model.py
index f755363f091a12f11df425ccaea3d7d7831090c2..77f576a0bd80cc38aaaef1050c04ac4a63fb7eb4 100644
--- a/examples/gw_examples/injection_examples/create_your_own_source_model.py
+++ b/examples/gw_examples/injection_examples/create_your_own_source_model.py
@@ -6,48 +6,97 @@ import bilby
 import numpy as np
 
 # First set up logging and some output directories and labels
-outdir = 'outdir'
-label = 'create_your_own_source_model'
+outdir = "outdir"
+label = "create_your_own_source_model"
 sampling_frequency = 4096
 duration = 1
 
 
 # Here we define our source model - this is the sine-Gaussian model in the
 # frequency domain.
-def sine_gaussian(f, A, f0, tau, phi0, geocent_time, ra, dec, psi):
-    arg = -(np.pi * tau * (f - f0))**2 + 1j * phi0
-    plus = np.sqrt(np.pi) * A * tau * np.exp(arg) / 2.
+def gaussian(frequency_array, amplitude, f0, tau, phi0):
+    r"""
+    Our custom source model, this is just a Gaussian in frequency with
+    variable global phase.
+
+    .. math::
+
+        \tilde{h}_{+}(f) = \frac{\sqrt{\pi} A \tau}{2}
+        e^{- (\pi \tau (f - f_{0}))^2 + i \phi_{0}} \\
+        \tilde{h}_{\times}(f) = \tilde{h}_{+}(f) e^{i \pi / 2}
+
+
+    Parameters
+    ----------
+    frequency_array: array-like
+        The frequencies to evaluate the model at. This is required for all
+        Bilby source models.
+    amplitude: float
+        An overall amplitude prefactor.
+    f0: float
+        The central frequency.
+    tau: float
+        The damping time; this sets the width of the Gaussian in frequency.
+    phi0: float
+        The reference phase.
+
+    Returns
+    -------
+    dict:
+        A dictionary containing "plus" and "cross" entries.
+
+    """
+    arg = -((np.pi * tau * (frequency_array - f0)) ** 2) + 1j * phi0
+    plus = np.sqrt(np.pi) * amplitude * tau * np.exp(arg) / 2.0
     cross = plus * np.exp(1j * np.pi / 2)
-    return {'plus': plus, 'cross': cross}
+    return {"plus": plus, "cross": cross}
 
 
 # We now define some parameters that we will inject
-injection_parameters = dict(A=1e-23, f0=100, tau=1, phi0=0, geocent_time=0,
-                            ra=0, dec=0, psi=0)
+injection_parameters = dict(
+    amplitude=1e-23, f0=100, tau=1, phi0=0, geocent_time=0, ra=0, dec=0, psi=0
+)
 
 # Now we pass our source function to the WaveformGenerator
 waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
-    frequency_domain_source_model=sine_gaussian)
+    duration=duration,
+    sampling_frequency=sampling_frequency,
+    frequency_domain_source_model=gaussian,
+)
 
 # Set up interferometers.
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 0.5,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator,
+    parameters=injection_parameters,
+    raise_error=False,
+)
 
 # Here we define the priors for the search. We use the injection parameters
 # except for the amplitude and f0
 prior = injection_parameters.copy()
-prior['A'] = bilby.core.prior.LogUniform(minimum=1e-25, maximum=1e-21, name='A')
-prior['f0'] = bilby.core.prior.Uniform(90, 110, 'f')
+prior["amplitude"] = bilby.core.prior.LogUniform(
+    minimum=1e-25, maximum=1e-21, latex_label="$\\mathcal{A}$"
+)
+prior["f0"] = bilby.core.prior.Uniform(90, 110, latex_label="$f_{0}$")
 
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 result = bilby.core.sampler.run_sampler(
-    likelihood, prior, sampler='dynesty', outdir=outdir, label=label,
-    resume=False, sample='unif', injection_parameters=injection_parameters)
+    likelihood,
+    prior,
+    sampler="dynesty",
+    outdir=outdir,
+    label=label,
+    resume=False,
+    sample="unif",
+    injection_parameters=injection_parameters,
+)
 result.plot_corner()
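
Since a frequency-domain source model is just a function of the frequency array
plus the sampled parameters, it can be sanity-checked directly before handing it
to the WaveformGenerator. A quick check using gaussian as defined above, with
hypothetical parameter values:

    import numpy as np

    frequency_array = np.linspace(0, 2048, 4097)
    polarizations = gaussian(frequency_array, amplitude=1e-23, f0=100, tau=1, phi0=0)
    print(polarizations["plus"].shape)  # matches the frequency grid
    peak = frequency_array[np.argmax(abs(polarizations["plus"]))]
    print(peak)  # the model peaks at f0 = 100.0 Hz
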
diff --git a/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py b/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py
index 11290fde30e52f9d1851654d3ffa4349f697fbe8..1427f8ceb2813a262db031a9e494a0ce580c012f 100644
--- a/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py
+++ b/examples/gw_examples/injection_examples/create_your_own_time_domain_source_model.py
@@ -6,62 +6,114 @@ noise in two interferometers (LIGO Livingston and Hanford at design
 sensitivity), and then recovered.
 """
 
-import numpy as np
 import bilby
+import numpy as np
 
 
 # define the time-domain model
-def time_domain_damped_sinusoid(
-        time, amplitude, damping_time, frequency, phase, t0):
-    """
+def time_domain_damped_sinusoid(time, amplitude, damping_time, frequency, phase, t0):
+    r"""
     This example only creates a linearly polarised signal with only plus
     polarisation.
+
+    .. math::
+
+        h_{+}(t) =
+            \Theta(t - t_{0}) A
+            e^{-(t - t_{0}) / \tau}
+            \sin \left( 2 \pi f (t - t_{0}) + \phi \right)
+
+    Parameters
+    ----------
+    time: array-like
+        The times at which to evaluate the model. This is required for all
+        time-domain models.
+    amplitude: float
+        The peak amplitude.
+    damping_time: float
+        The damping time of the exponential.
+    frequency: float
+        The frequency of the oscillations.
+    phase: float
+        The initial phase of the signal.
+    t0: float
+        The offset of the start of the signal from the start time.
+
+    Returns
+    -------
+    dict:
+        A dictionary containing "plus" and "cross" entries.
+
     """
     plus = np.zeros(len(time))
     tidx = time >= t0
-    plus[tidx] = amplitude * np.exp(-(time[tidx] - t0) / damping_time) *\
-        np.sin(2 * np.pi * frequency * (time[tidx] - t0) + phase)
+    plus[tidx] = (
+        amplitude
+        * np.exp(-(time[tidx] - t0) / damping_time)
+        * np.sin(2 * np.pi * frequency * (time[tidx] - t0) + phase)
+    )
     cross = np.zeros(len(time))
-    return {'plus': plus, 'cross': cross}
+    return {"plus": plus, "cross": cross}
 
 
 # define parameters to inject.
-injection_parameters = dict(amplitude=5e-22, damping_time=0.1, frequency=50,
-                            phase=0, ra=0, dec=0, psi=0, t0=0., geocent_time=0.)
+injection_parameters = dict(
+    amplitude=5e-22,
+    damping_time=0.1,
+    frequency=50,
+    phase=0,
+    ra=0,
+    dec=0,
+    psi=0,
+    t0=0.0,
+    geocent_time=0.0,
+)
 
-duration = 1.0
+duration = 1
 sampling_frequency = 1024
-outdir = 'outdir'
-label = 'time_domain_source_model'
+outdir = "outdir"
+label = "time_domain_source_model"
 
 # call the waveform_generator to create our waveform model.
 waveform = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     time_domain_source_model=time_domain_damped_sinusoid,
-    start_time=injection_parameters['geocent_time'] - 0.5)
+    start_time=injection_parameters["geocent_time"] - 0.5,
+)
 
 # inject the signal into three interferometers
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 0.5)
-ifos.inject_signal(waveform_generator=waveform,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 0.5,
+)
+ifos.inject_signal(
+    waveform_generator=waveform, parameters=injection_parameters, raise_error=False
+)
 
 #  create the priors
 prior = injection_parameters.copy()
-prior['amplitude'] = bilby.core.prior.LogUniform(1e-23, 1e-21, r'$h_0$')
-prior['damping_time'] = bilby.core.prior.Uniform(
-    0.01, 1, r'damping time', unit='$s$')
-prior['frequency'] = bilby.core.prior.Uniform(0, 200, r'frequency', unit='Hz')
-prior['phase'] = bilby.core.prior.Uniform(-np.pi / 2, np.pi / 2, r'$\phi$')
+prior["amplitude"] = bilby.core.prior.LogUniform(1e-23, 1e-21, r"$h_0$")
+prior["damping_time"] = bilby.core.prior.Uniform(0.01, 1, r"damping time", unit="$s$")
+prior["frequency"] = bilby.core.prior.Uniform(0, 200, r"frequency", unit="Hz")
+prior["phase"] = bilby.core.prior.Uniform(-np.pi / 2, np.pi / 2, r"$\phi$")
 
 # define likelihood
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(ifos, waveform)
 
 # launch sampler
 result = bilby.core.sampler.run_sampler(
-    likelihood, prior, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood,
+    prior,
+    sampler="dynesty",
+    npoints=500,
+    walks=5,
+    nact=3,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 result.plot_corner()
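
The same kind of direct check applies to time-domain models: the returned
polarisations should be zero before t0 and bounded by the peak amplitude. A short
sketch using time_domain_damped_sinusoid as defined above:

    import numpy as np

    time = np.linspace(0, 1, 1024, endpoint=False)
    h = time_domain_damped_sinusoid(
        time, amplitude=5e-22, damping_time=0.1, frequency=50, phase=0, t0=0.2
    )
    assert np.all(h["plus"][time < 0.2] == 0)  # silent before t0
    assert np.max(np.abs(h["plus"])) <= 5e-22  # bounded by the amplitude
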
diff --git a/examples/gw_examples/injection_examples/custom_proposal_example.py b/examples/gw_examples/injection_examples/custom_proposal_example.py
index 2a19f4ee6944ff7aceff31591dad3c837ce0a73c..47019dc858da3d9cd640395a5174955bb40ac012 100644
--- a/examples/gw_examples/injection_examples/custom_proposal_example.py
+++ b/examples/gw_examples/injection_examples/custom_proposal_example.py
@@ -1,71 +1,105 @@
 #!/usr/bin/env python
 """
 Tutorial for running cpnest with custom jump proposals.
+
+This example takes longer than most to run.
+
+Due to how cpnest creates parallel processes, the multiprocessing start method
+needs to be set on some operating systems.
 """
+import multiprocessing
+
+multiprocessing.set_start_method("fork")  # noqa
 
-import numpy as np
 import bilby.gw.sampler.proposal
+import numpy as np
 from bilby.core.sampler import proposal
 
-
 # The setup here is the same as in fast_tutorial.py. Look there for descriptive explanations.
 
-duration = 4.
-sampling_frequency = 2048.
+duration = 4
+sampling_frequency = 1024
 
-outdir = 'outdir'
-label = 'custom_jump_proposals'
+outdir = "outdir"
+label = "custom_jump_proposals"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 np.random.seed(88170235)
 
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50., minimum_frequency=20.)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomPv2",
+    reference_frequency=50.0,
+    minimum_frequency=20.0,
+)
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    waveform_arguments=waveform_arguments)
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+    waveform_arguments=waveform_arguments,
+)
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 priors = bilby.gw.prior.BBHPriorDict()
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    minimum=injection_parameters['geocent_time'] - 1,
-    maximum=injection_parameters['geocent_time'] + 1,
-    name='geocent_time', latex_label='$t_c$', unit='$s$')
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'geocent_time']:
+for key in ["a_1", "a_2", "tilt_1", "tilt_2", "phi_12", "phi_jl", "geocent_time"]:
     priors[key] = injection_parameters[key]
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Definition of the custom jump proposals as a JumpProposalCycle. The first argument is a list
 # of all allowed jump proposals. The second argument is a list of weights for the respective
 # jump proposals.
 
 jump_proposals = proposal.JumpProposalCycle(
-    [proposal.EnsembleWalk(priors=priors),
-     proposal.EnsembleStretch(priors=priors),
-     proposal.DifferentialEvolution(priors=priors),
-     proposal.EnsembleEigenVector(priors=priors),
-     bilby.gw.sampler.proposal.SkyLocationWanderJump(priors=priors),
-     bilby.gw.sampler.proposal.CorrelatedPolarisationPhaseJump(priors=priors),
-     bilby.gw.sampler.proposal.PolarisationPhaseJump(priors=priors),
-     proposal.DrawFlatPrior(priors=priors)],
-    weights=[2, 2, 5, 1, 1, 1, 1, 1])
+    [
+        proposal.EnsembleWalk(priors=priors),
+        proposal.EnsembleStretch(priors=priors),
+        proposal.DifferentialEvolution(priors=priors),
+        proposal.EnsembleEigenVector(priors=priors),
+        bilby.gw.sampler.proposal.SkyLocationWanderJump(priors=priors),
+        bilby.gw.sampler.proposal.CorrelatedPolarisationPhaseJump(priors=priors),
+        bilby.gw.sampler.proposal.PolarisationPhaseJump(priors=priors),
+        proposal.DrawFlatPrior(priors=priors),
+    ],
+    weights=[2, 2, 5, 1, 1, 1, 1, 1],
+)
 
 # Run cpnest with the proposals kwarg specified.
 # Make sure to have a version of cpnest installed that supports custom proposals.
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='cpnest', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label,
-    proposals=jump_proposals)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="cpnest",
+    npoints=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+    proposals=jump_proposals,
+)
 
 # Make a corner plot.
 result.plot_corner()
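
Schematically, a weighted proposal cycle draws each jump type with probability
proportional to its weight, so DifferentialEvolution (weight 5) is used most
often here. A toy illustration of weighted selection with stand-in names (not
cpnest's actual machinery):

    import numpy as np

    rng = np.random.default_rng(1)
    proposals = ["walk", "stretch", "diff_evo", "eigenvector"]
    weights = np.array([2.0, 2.0, 5.0, 1.0])
    print(rng.choice(proposals, size=10, p=weights / weights.sum()))
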
diff --git a/examples/gw_examples/injection_examples/eccentric_inspiral.py b/examples/gw_examples/injection_examples/eccentric_inspiral.py
index d081cb575bc15855f1e05d06e5ec0426c62ac82d..55a5853ea4e7f135ef5bb37b7aa1e3c79e775cdd 100644
--- a/examples/gw_examples/injection_examples/eccentric_inspiral.py
+++ b/examples/gw_examples/injection_examples/eccentric_inspiral.py
@@ -6,84 +6,121 @@ similar to GW150914.
 
 This uses the same binary parameters that were used to make Figures 1, 2 & 5 in
 Lower et al. (2018) -> arXiv:1806.05350.
-
-For a more comprehensive look at what goes on in each step, refer to the
-"basic_tutorial.py" example.
 """
 
-import numpy as np
 import bilby
+import numpy as np
 
-duration = 64.
-sampling_frequency = 256.
+duration = 64
+sampling_frequency = 256
 
-outdir = 'outdir'
-label = 'eccentric_GW140914'
+outdir = "outdir"
+label = "eccentric_GW150914"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.
 np.random.seed(150914)
 
 injection_parameters = dict(
-    mass_1=35., mass_2=30., eccentricity=0.1, luminosity_distance=440.,
-    theta_jn=0.4, psi=0.1, phase=1.2, geocent_time=1180002601.0, ra=45, dec=5.73)
-
-waveform_arguments = dict(waveform_approximant='EccentricFD',
-                          reference_frequency=10., minimum_frequency=10.)
+    mass_1=35.0,
+    mass_2=30.0,
+    eccentricity=0.1,
+    luminosity_distance=440.0,
+    theta_jn=0.4,
+    psi=0.1,
+    phase=1.2,
+    geocent_time=1180002601.0,
+    ra=45,
+    dec=5.73,
+)
+
+waveform_arguments = dict(
+    waveform_approximant="EccentricFD", reference_frequency=10.0, minimum_frequency=10.0
+)
 
 # Create the waveform_generator using the LAL eccentric black hole no spins
 # source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_eccentric_binary_black_hole_no_spins,
-    parameters=injection_parameters, waveform_arguments=waveform_arguments)
+    parameters=injection_parameters,
+    waveform_arguments=waveform_arguments,
+)
 
 
 # Setting up three interferometers (LIGO-Hanford (H1), LIGO-Livingston (L1), and
 # Virgo (V1)) at their design sensitivities. The maximum frequency is set just
 # prior to the point at which the waveform model terminates. This is to avoid
 # any biases introduced from using a sharply terminating waveform model.
-minimum_frequency = 10.
-maximum_frequency = 128.
+minimum_frequency = 10.0
+maximum_frequency = 128.0
 
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 for ifo in ifos:
     ifo.minimum_frequency = minimum_frequency
     ifo.maximum_frequency = maximum_frequency
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] + 2 - duration,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # Now we set up the priors on each of the binary parameters.
 priors = bilby.core.prior.PriorDict()
 priors["mass_1"] = bilby.core.prior.Uniform(
-    name='mass_1', minimum=5, maximum=60, unit='$M_{\\odot}$')
+    name="mass_1", minimum=5, maximum=60, unit="$M_{\\odot}$", latex_label="$m_1$"
+)
 priors["mass_2"] = bilby.core.prior.Uniform(
-    name='mass_2', minimum=5, maximum=60, unit='$M_{\\odot}$')
+    name="mass_2", minimum=5, maximum=60, unit="$M_{\\odot}$", latex_label="$m_2$"
+)
 priors["eccentricity"] = bilby.core.prior.LogUniform(
-    name='eccentricity', latex_label='$e$', minimum=1e-4, maximum=0.4)
-priors["luminosity_distance"] = bilby.gw.prior.UniformComovingVolume(
-    name='luminosity_distance', minimum=1e2, maximum=2e3)
-priors["dec"] = bilby.core.prior.Cosine(name='dec')
+    name="eccentricity", latex_label="$e$", minimum=1e-4, maximum=0.4
+)
+priors["luminosity_distance"] = bilby.gw.prior.UniformSourceFrame(
+    name="luminosity_distance", minimum=1e2, maximum=2e3
+)
+priors["dec"] = bilby.core.prior.Cosine(name="dec")
 priors["ra"] = bilby.core.prior.Uniform(
-    name='ra', minimum=0, maximum=2 * np.pi)
-priors["theta_jn"] = bilby.core.prior.Sine(name='theta_jn')
-priors["psi"] = bilby.core.prior.Uniform(name='psi', minimum=0, maximum=np.pi)
+    name="ra", minimum=0, maximum=2 * np.pi, boundary="periodic"
+)
+priors["theta_jn"] = bilby.core.prior.Sine(name="theta_jn")
+priors["psi"] = bilby.core.prior.Uniform(
+    name="psi", minimum=0, maximum=np.pi, boundary="periodic"
+)
 priors["phase"] = bilby.core.prior.Uniform(
-    name='phase', minimum=0, maximum=2 * np.pi)
+    name="phase", minimum=0, maximum=2 * np.pi, boundary="periodic"
+)
 priors["geocent_time"] = bilby.core.prior.Uniform(
-    1180002600.9, 1180002601.1, name='geocent_time', unit='s')
+    injection_parameters["geocent_time"] - 0.1,
+    injection_parameters["geocent_time"] + 0.1,
+    name="geocent_time",
+    unit="s",
+)
 
 # Initialising the likelihood function.
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos,
+    waveform_generator=waveform_generator,
+    priors=priors,
+    time_marginalization=True,
+    distance_marginalization=True,
+    phase_marginalization=True,
+)
 
 # Now we run sampler (PyMultiNest in our case).
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='pymultinest', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="pymultinest",
+    npoints=500,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # And finally we make some plots of the output posteriors.
 result.plot_corner()
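
The three marginalization flags integrate the likelihood over time, distance, and
phase against their priors, shrinking the space the sampler must explore to the
remaining parameters. Conceptually, for a flat prior this is a prior-averaged
likelihood; a toy 1-D quadrature version (not bilby's implementation, which uses
analytic and lookup-table shortcuts):

    import numpy as np

    def log_likelihood(phase):
        return 10 * np.cos(phase - 1.2)  # toy likelihood peaked at phase = 1.2

    phase_grid = np.linspace(0, 2 * np.pi, 1001)
    log_l = log_likelihood(phase_grid)
    # log of (1 / 2 pi) * integral of L(phi) d phi, computed stably
    log_l_marg = log_l.max() + np.log(
        np.trapz(np.exp(log_l - log_l.max()), phase_grid) / (2 * np.pi)
    )
    print(log_l_marg)
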
diff --git a/examples/gw_examples/injection_examples/fake_sampler_example.py b/examples/gw_examples/injection_examples/fake_sampler_example.py
index e1e7d600d88be737405d12dcdd21be3d5fbcba90..e4632ccb633d6306a8cae177cbe26b7c17a6252f 100755
--- a/examples/gw_examples/injection_examples/fake_sampler_example.py
+++ b/examples/gw_examples/injection_examples/fake_sampler_example.py
@@ -1,119 +1,143 @@
 #!/usr/bin/env python
 """
-Read ROQ posterior and calculate full likelihood at same parameter space points.
+Demonstrate using the FakeSampler to reweight a result to include higher-order
+emission modes. This is a simplified version of the method presented in
+arXiv:1905.05477; the method, however, can be applied to a much wider range of
+initial and target likelihoods.
 """
 
-import numpy as np
-import deepdish as dd
 import bilby
 import matplotlib.pyplot as plt
+import numpy as np
 
 
-def make_comparison_histograms(file_full, file_roq):
-    # Returns a dictionary
-    data_full = dd.io.load(file_full)
-    data_roq = dd.io.load(file_roq)
-
-    # These are pandas dataframes
-    pos_full = data_full['posterior']
-    pos_roq = data_roq['posterior']
+def make_comparison_histograms(result_1, result_2):
+    pos_full = result_1.posterior
+    pos_simple = result_2.posterior
 
     plt.figure()
-    plt.hist(pos_full['log_likelihood_evaluations'], 50, label='full', histtype='step')
-    plt.hist(pos_roq['log_likelihood_evaluations'], 50, label='roq', histtype='step')
-    plt.xlabel(r'delta_logl')
+    plt.hist(
+        pos_full["log_likelihood"],
+        50,
+        label=result_1.label,
+        histtype="step",
+        density=True,
+    )
+    plt.hist(
+        pos_simple["log_likelihood"],
+        50,
+        label=result_2.label,
+        histtype="step",
+        density=True,
+    )
+    plt.xlabel(r"delta_logl")
     plt.legend(loc=2)
-    plt.savefig('delta_logl.pdf')
-    plt.close()
-
-    plt.figure()
-    delta_dlogl = pos_full['log_likelihood_evaluations'] - pos_roq['log_likelihood_evaluations']
-    plt.hist(delta_dlogl, 50)
-    plt.xlabel(r'delta_logl_full - delta_logl_roq')
-    plt.savefig('delta_delta_logl.pdf')
-    plt.close()
-
-    plt.figure()
-    delta_dlogl = np.abs(pos_full['log_likelihood_evaluations'] - pos_roq['log_likelihood_evaluations'])
-    bins = np.logspace(np.log10(delta_dlogl.min()), np.log10(delta_dlogl.max()), 25)
-    plt.hist(delta_dlogl, bins=bins)
-    plt.xscale('log')
-    plt.xlabel(r'|delta_logl_full - delta_logl_roq|')
-    plt.savefig('log_abs_delta_delta_logl.pdf')
+    plt.savefig(f"{result_1.outdir}/delta_logl.pdf")
     plt.close()
 
 
 def main():
-    outdir = 'outdir_full'
-    label = 'full'
+    outdir = "outdir"
 
     np.random.seed(170808)
 
     duration = 4
-    sampling_frequency = 2048
-    noise = 'zero'
-
-    sampler = 'fake_sampler'
-    # This example assumes that the following posterior file exists.
-    # It comes from a run using the full likelihood using the same
-    # injection and sampling parameters, but the ROQ likelihood.
-    # See roq_example.py for such an example.
-    sample_file = 'outdir_dynesty_zero_noise_SNR22/roq_result.h5'
+    sampling_frequency = 1024
 
     injection_parameters = dict(
-        chirp_mass=36., mass_ratio=0.9, a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0,
-        phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., iota=0.4, psi=0.659,
-        phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
-
-    waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                              reference_frequency=20.0, minimum_frequency=20.0)
+        chirp_mass=36.0,
+        mass_ratio=0.2,
+        chi_1=0.4,
+        chi_2=0.3,
+        luminosity_distance=2000.0,
+        theta_jn=0.4,
+        psi=0.659,
+        phase=1.3,
+        geocent_time=1126259642.413,
+        ra=1.375,
+        dec=-1.2108,
+    )
+
+    waveform_arguments = dict(
+        waveform_approximant="IMRPhenomXAS",
+        reference_frequency=20.0,
+        minimum_frequency=20.0,
+    )
 
     waveform_generator = bilby.gw.WaveformGenerator(
-        duration=duration, sampling_frequency=sampling_frequency,
+        duration=duration,
+        sampling_frequency=sampling_frequency,
         frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
         waveform_arguments=waveform_arguments,
-        parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
-
-    ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
-
-    if noise == 'Gaussian':
-        ifos.set_strain_data_from_power_spectral_densities(
-            sampling_frequency=sampling_frequency, duration=duration,
-            start_time=injection_parameters['geocent_time'] - 3)
-    elif noise == 'zero':
-        ifos.set_strain_data_from_zero_noise(
-            sampling_frequency=sampling_frequency, duration=duration,
-            start_time=injection_parameters['geocent_time'] - 3)
-
-    ifos.inject_signal(waveform_generator=waveform_generator,
-                       parameters=injection_parameters)
-
-    priors = bilby.gw.prior.BBHPriorDict()
-    for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'iota', 'psi', 'ra',
-                'dec', 'phi_12', 'phi_jl', 'luminosity_distance']:
+        parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
+    )
+
+    ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
+
+    ifos.set_strain_data_from_zero_noise(
+        sampling_frequency=sampling_frequency,
+        duration=duration,
+        start_time=injection_parameters["geocent_time"] - 2,
+    )
+
+    ifos.inject_signal(
+        waveform_generator=waveform_generator,
+        parameters=injection_parameters,
+    )
+
+    priors = bilby.gw.prior.BBHPriorDict(aligned_spin=True)
+    for key in [
+        "luminosity_distance",
+        "theta_jn",
+        "phase",
+        "psi",
+        "ra",
+        "dec",
+        "geocent_time",
+    ]:
         priors[key] = injection_parameters[key]
-    priors.pop('mass_1')
-    priors.pop('mass_2')
-    priors['chirp_mass'] = bilby.core.prior.Uniform(
-        15, 40, latex_label='$\\mathcal{M}$')
-    priors['mass_ratio'] = bilby.core.prior.Uniform(0.5, 1, latex_label='$q$')
-    priors['geocent_time'] = bilby.core.prior.Uniform(
-        injection_parameters['geocent_time'] - 0.1,
-        injection_parameters['geocent_time'] + 0.1, latex_label='$t_c$', unit='s')
-
-    likelihood = bilby.gw.GravitationalWaveTransient(
-        interferometers=ifos, waveform_generator=waveform_generator)
-
-    result = bilby.run_sampler(
-        likelihood=likelihood, priors=priors, sampler=sampler, sample_file=sample_file,
-        injection_parameters=injection_parameters, outdir=outdir, label=label)
-
-    # Make a corner plot.
-    result.plot_corner()
-
-    # Compare full and ROQ likelihoods
-    make_comparison_histograms(outdir + '/%s_result.h5' % label, sample_file)
-
-
-if __name__ == '__main__':
+    priors["chirp_mass"].maximum = 45
+
+    likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+        interferometers=ifos, waveform_generator=waveform_generator
+    )
+
+    # perform the initial sampling with our simple model
+    original_result = bilby.run_sampler(
+        likelihood=likelihood,
+        priors=priors,
+        sampler="dynesty",
+        walks=10,
+        nact=3,
+        bound="single",
+        injection_parameters=injection_parameters,
+        outdir=outdir,
+        label="primary_mode_only",
+        save="hdf5",
+    )
+
+    # update the waveform generator to use our higher-order mode waveform
+    likelihood.waveform_generator.waveform_arguments[
+        "waveform_approximant"
+    ] = "IMRPhenomXHM"
+
+    # call the FakeSampler to compute the new likelihoods
+    new_result = bilby.run_sampler(
+        likelihood=likelihood,
+        priors=priors,
+        sampler="fake_sampler",
+        sample_file=f"{outdir}/{original_result.label}_result.hdf5",
+        injection_parameters=injection_parameters,
+        outdir=outdir,
+        verbose=False,
+        label="higher_order_mode",
+        save="hdf5",
+    )
+
+    # make some comparison plots
+    bilby.core.result.plot_multiple([original_result, new_result])
+    make_comparison_histograms(new_result, original_result)
+
+
+if __name__ == "__main__":
     main()
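
Beyond the comparison histograms, the per-sample likelihood differences give an
importance-sampling estimate of the evidence ratio between the two waveform
models. A sketch with stand-in arrays (with real results, use the posterior
log_likelihood columns from the two runs):

    import numpy as np
    from scipy.special import logsumexp

    old_logl = np.random.default_rng(3).normal(loc=100, size=5000)  # stand-ins
    new_logl = old_logl + 0.3
    log_bayes_factor = logsumexp(new_logl - old_logl) - np.log(old_logl.size)
    print(log_bayes_factor)  # 0.3 here, since the shift is constant
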
diff --git a/examples/gw_examples/injection_examples/fast_tutorial.py b/examples/gw_examples/injection_examples/fast_tutorial.py
index 3a2e18fc135b6225e369e28643d5749d0b9db50d..f462008a70994982d47c9cbd69179161c96aae3c 100644
--- a/examples/gw_examples/injection_examples/fast_tutorial.py
+++ b/examples/gw_examples/injection_examples/fast_tutorial.py
@@ -8,18 +8,18 @@ and distance using a uniform in comoving volume prior on luminosity distance
 between 100 Mpc and 5 Gpc; the cosmology is Planck15.
 """
 
-import numpy as np
 import bilby
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment that we're
 # going to inject the signal into
-duration = 4.
-sampling_frequency = 2048.
+duration = 4.0
+sampling_frequency = 2048.0
 minimum_frequency = 20
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'fast_tutorial'
+outdir = "outdir"
+label = "fast_tutorial"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -30,31 +30,51 @@ np.random.seed(88170235)
 # parameters, including masses of the two black holes (mass_1, mass_2),
 # spins of both black holes (a, tilt, phi), etc.
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
 # Fixed arguments passed into the source model
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.,
-                          minimum_frequency=minimum_frequency)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomPv2",
+    reference_frequency=50.0,
+    minimum_frequency=minimum_frequency,
+)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.  In this case we'll use two interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design
 # sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # Set up a PriorDict, which inherits from dict.
 # By default we will sample all terms in the signal models.  However, this will
@@ -67,12 +87,19 @@ ifos.inject_signal(waveform_generator=waveform_generator,
 # distance, which means those are the parameters that will be included in the
 # sampler.  If we do nothing, then the default priors get used.
 priors = bilby.gw.prior.BBHPriorDict()
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    minimum=injection_parameters['geocent_time'] - 0.1,
-    maximum=injection_parameters['geocent_time'] + 0.1,
-    name='geocent_time', latex_label='$t_c$', unit='$s$')
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi', 'ra',
-            'dec', 'geocent_time', 'phase']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "psi",
+    "ra",
+    "dec",
+    "geocent_time",
+    "phase",
+]:
     priors[key] = injection_parameters[key]
 
-# Perform a check that the prior does not extend to a parameter space longer than the data
+# Check that the prior does not allow signals longer than the data segment
@@ -81,12 +108,19 @@ priors.validate_prior(duration, minimum_frequency)
 # Initialise the likelihood by passing in the interferometer data (ifos) and
 # the waveform generator
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    npoints=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # Make a corner plot.
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/how_to_specify_the_prior.py b/examples/gw_examples/injection_examples/how_to_specify_the_prior.py
index 851a3aab37cfbc1ce915af5e80e104dddf4b3a80..816288c5a727dc482c1524f7357e638004271228 100644
--- a/examples/gw_examples/injection_examples/how_to_specify_the_prior.py
+++ b/examples/gw_examples/injection_examples/how_to_specify_the_prior.py
@@ -4,58 +4,94 @@ Tutorial to demonstrate how to specify the prior distributions used for
 parameter estimation.
 """
 
-import numpy as np
 import bilby
+import numpy as np
 
-
-duration = 4.
-sampling_frequency = 2048.
-outdir = 'outdir'
+duration = 4
+sampling_frequency = 1024
+outdir = "outdir"
 
 np.random.seed(151012)
 
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=4000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=4000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50., minimum_frequency=20.)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomXPHM",
+    reference_frequency=50.0,
+    minimum_frequency=20.0,
+)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # Set up prior
 # This loads in a predefined set of priors for BBHs.
 priors = bilby.gw.prior.BBHPriorDict()
 # These parameters will not be sampled
-for key in ['tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'phase', 'theta_jn', 'ra',
-            'dec', 'geocent_time', 'psi']:
+for key in [
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "phase",
+    "theta_jn",
+    "ra",
+    "dec",
+    "geocent_time",
+    "psi",
+]:
     priors[key] = injection_parameters[key]
 # We can make uniform distributions.
-priors['mass_2'] = bilby.core.prior.Uniform(
-    name='mass_2', minimum=20, maximum=40, unit='$M_{\\odot}$')
+del priors["chirp_mass"], priors["mass_ratio"]
+# We can make uniform distributions.
+priors["mass_1"] = bilby.core.prior.Uniform(
+    name="mass_1", minimum=20, maximum=40, unit="$M_{\\odot}$"
+)
+priors["mass_2"] = bilby.core.prior.Uniform(
+    name="mass_2", minimum=20, maximum=40, unit="$M_{\\odot}$"
+)
 # We can make a power-law distribution, p(x) ~ x^{alpha}
 # Note: alpha=0 is a uniform distribution, alpha=-1 is uniform-in-log
-priors['a_1'] = bilby.core.prior.PowerLaw(
-    name='a_1', alpha=-1, minimum=1e-2, maximum=1)
+priors["a_1"] = bilby.core.prior.PowerLaw(name="a_1", alpha=-1, minimum=1e-2, maximum=1)
 # We can define a prior from an array as follows.
 # Note: this doesn't have to be properly normalised.
 a_2 = np.linspace(0, 1, 1001)
-p_a_2 = a_2 ** 4
-priors['a_2'] = bilby.core.prior.Interped(
-    name='a_2', xx=a_2, yy=p_a_2, minimum=0, maximum=0.5)
+p_a_2 = a_2**4
+priors["a_2"] = bilby.core.prior.Interped(
+    name="a_2", xx=a_2, yy=p_a_2, minimum=0, maximum=0.5
+)
 # Additionally, we have Gaussian, TruncatedGaussian, Sine and Cosine.
-# It's also possible to load an interpolate a prior from a file.
+# It's also possible to load and interpolate a prior from a file.
 # Finally, if you don't specify any necessary parameters it will be filled in
@@ -64,10 +100,16 @@ priors['a_2'] = bilby.core.prior.Interped(
 
 # Initialise GravitationalWaveTransient
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Run sampler
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', outdir=outdir,
-    injection_parameters=injection_parameters, label='specify_prior')
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    outdir=outdir,
+    injection_parameters=injection_parameters,
+    label="specify_prior",
+)
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/marginalized_likelihood.py b/examples/gw_examples/injection_examples/marginalized_likelihood.py
index c28efca13d4d9ddda064d4e6eef14c72b8086407..55cd2d4c7774613b5f520cbadccbad915804aa03 100644
--- a/examples/gw_examples/injection_examples/marginalized_likelihood.py
+++ b/examples/gw_examples/injection_examples/marginalized_likelihood.py
@@ -9,55 +9,90 @@ parameter can be recovered in post-processing.
 import bilby
 import numpy as np
 
-
-duration = 4.
-sampling_frequency = 2048.
-outdir = 'outdir'
-label = 'marginalized_likelihood'
+duration = 4
+sampling_frequency = 1024
+outdir = "outdir"
+label = "marginalized_likelihood"
 
 np.random.seed(170608)
 
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=4000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=4000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # Set up prior
 priors = bilby.gw.prior.BBHPriorDict()
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    minimum=injection_parameters['geocent_time'] - 1,
-    maximum=injection_parameters['geocent_time'] + 1,
-    name='geocent_time', latex_label='$t_c$', unit='$s$')
+priors["geocent_time"] = bilby.core.prior.Uniform(
+    minimum=injection_parameters["geocent_time"] - 0.1,
+    maximum=injection_parameters["geocent_time"] + 0.1,
+    name="geocent_time",
+    latex_label="$t_c$",
+    unit="$s$",
+)
 # These parameters will not be sampled
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'theta_jn', 'ra',
-            'dec', 'mass_1', 'mass_2']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "theta_jn",
+    "ra",
+    "dec",
+    "mass_1",
+    "mass_2",
+]:
     priors[key] = injection_parameters[key]
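+# The default priors are defined on chirp mass and mass ratio; remove them
+# as the component masses are fixed above.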
+del priors["chirp_mass"], priors["mass_ratio"]
 
 # Initialise GravitationalWaveTransient
-# Note that we now need to pass the: priors and flags for each thing that's
+# Note that we now need to pass the priors and flags for each thing that's
 # being marginalised. A lookup table is used for distance marginalisation which
 # takes a few minutes to build.
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator, priors=priors,
-    distance_marginalization=True, phase_marginalization=True,
-    time_marginalization=True)
+    interferometers=ifos,
+    waveform_generator=waveform_generator,
+    priors=priors,
+    distance_marginalization=True,
+    phase_marginalization=True,
+    time_marginalization=True,
+)
 
 # Run sampler
 # Note that we've added an additional argument `conversion_function`, this is
@@ -66,7 +101,12 @@ likelihood = bilby.gw.GravitationalWaveTransient(
 # reconstructs posterior distributions for the parameters which were
 # marginalised over in the likelihood.
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty',
-    injection_parameters=injection_parameters, outdir=outdir, label=label,
-    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters,
+)
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/multiband_example.py b/examples/gw_examples/injection_examples/multiband_example.py
index ce6c9f436319a96695d0c9337b3e69701d454bec..e43b6ca2190a399442723639b9fd79ef06c418d6 100644
--- a/examples/gw_examples/injection_examples/multiband_example.py
+++ b/examples/gw_examples/injection_examples/multiband_example.py
@@ -4,72 +4,104 @@ Example of how to use the multi-banding method (see Morisaki, (2021) arXiv:
 2104.07813) for a binary neutron star simulated signal in Gaussian noise.
 """
 
-import numpy as np
-
 import bilby
+import numpy as np
 
-outdir = 'outdir'
-label = 'multibanding'
+outdir = "outdir"
+label = "multibanding"
 
 np.random.seed(170808)
 
-minimum_frequency = 20.
-reference_frequency = 100.
-duration = 256.
-sampling_frequency = 4096.
-approximant = 'IMRPhenomD'
+minimum_frequency = 20
+reference_frequency = 100
+duration = 256
+sampling_frequency = 4096
+approximant = "IMRPhenomD"
 injection_parameters = dict(
-    chirp_mass=1.2, mass_ratio=0.8, chi_1=0., chi_2=0.,
-    ra=3.44616, dec=-0.408084, luminosity_distance=200.,
-    theta_jn=0.4, psi=0.659, phase=1.3, geocent_time=1187008882)
+    chirp_mass=1.2,
+    mass_ratio=0.8,
+    chi_1=0.0,
+    chi_2=0.0,
+    ra=3.44616,
+    dec=-0.408084,
+    luminosity_distance=200.0,
+    theta_jn=0.4,
+    psi=0.659,
+    phase=1.3,
+    geocent_time=1187008882,
+)
 
 # inject signal
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    waveform_arguments=dict(waveform_approximant=approximant,
-                            reference_frequency=reference_frequency),
-    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+    waveform_arguments=dict(
+        waveform_approximant=approximant, reference_frequency=reference_frequency
+    ),
+    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
+)
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - duration + 2.)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - duration + 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 for ifo in ifos:
     ifo.minimum_frequency = minimum_frequency
 
 # make waveform generator for likelihood evaluations
 search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.binary_black_hole_frequency_sequence,
-    waveform_arguments=dict(waveform_approximant=approximant,
-                            reference_frequency=reference_frequency),
-    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
+    waveform_arguments=dict(
+        waveform_approximant=approximant, reference_frequency=reference_frequency
+    ),
+    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
+)
 
 # make prior
 priors = bilby.gw.prior.BNSPriorDict()
-priors['chi_1'] = 0.
-priors['chi_2'] = 0.
-priors.pop('lambda_1')
-priors.pop('lambda_2')
-priors['chirp_mass'] = bilby.core.prior.Uniform(name='chirp_mass', minimum=1.15, maximum=1.25)
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    injection_parameters['geocent_time'] - 0.1,
-    injection_parameters['geocent_time'] + 0.1, latex_label='$t_c$', unit='s')
+priors["chi_1"] = 0
+priors["chi_2"] = 0
+del priors["lambda_1"], priors["lambda_2"]
+priors["chirp_mass"] = bilby.core.prior.Uniform(
+    name="chirp_mass", minimum=1.15, maximum=1.25
+)
+priors["geocent_time"] = bilby.core.prior.Uniform(
+    injection_parameters["geocent_time"] - 0.1,
+    injection_parameters["geocent_time"] + 0.1,
+    latex_label="$t_c$",
+    unit="s",
+)
 
 # make multi-banded likelihood
 likelihood = bilby.gw.likelihood.MBGravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=search_waveform_generator, priors=priors,
-    reference_chirp_mass=priors['chirp_mass'].minimum,
-    distance_marginalization=True, phase_marginalization=True
+    interferometers=ifos,
+    waveform_generator=search_waveform_generator,
+    priors=priors,
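+    # the reference chirp mass sets the frequency banding; the prior
+    # minimum (the longest possible signal) is a conservative choice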
+    reference_chirp_mass=priors["chirp_mass"].minimum,
+    distance_marginalization=True,
+    phase_marginalization=True,
 )
 
 # sampling
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty',
-    nlive=500, walks=100, maxmcmc=5000, nact=5,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    nlive=500,
+    walks=100,
+    maxmcmc=5000,
+    nact=5,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # Make a corner plot.
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/non_tensor.py b/examples/gw_examples/injection_examples/non_tensor.py
index 6c44216954f92aab978cac1be612d3ebbc17b042..1020a10002a869afffb9d82384844649b67f8c61 100644
--- a/examples/gw_examples/injection_examples/non_tensor.py
+++ b/examples/gw_examples/injection_examples/non_tensor.py
@@ -28,69 +28,99 @@ def vector_tensor_sine_gaussian(frequency_array, hrss, Q, frequency, epsilon):
         Relative size of the vector modes compared to the tensor modes.
     """
     waveform_polarizations = bilby.gw.source.sinegaussian(
-        frequency_array, hrss, Q, frequency)
+        frequency_array, hrss, Q, frequency
+    )
 
-    waveform_polarizations['x'] = epsilon * waveform_polarizations['plus']
-    waveform_polarizations['y'] = epsilon * waveform_polarizations['cross']
+    waveform_polarizations["x"] = epsilon * waveform_polarizations["plus"]
+    waveform_polarizations["y"] = epsilon * waveform_polarizations["cross"]
     return waveform_polarizations
 
 
-duration = 4.
-sampling_frequency = 2048.
+duration = 1
+sampling_frequency = 512
 
-outdir = 'outdir'
-label = 'vector_tensor'
+outdir = "outdir"
+label = "vector_tensor"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 np.random.seed(170801)
 
 injection_parameters = dict(
-    hrss=1e-22, Q=5.0, frequency=200.0, ra=1.375, dec=-1.2108,
-    geocent_time=1126259642.413, psi=2.659, epsilon=0.2)
-
-waveform_generator =\
-    bilby.gw.waveform_generator.WaveformGenerator(
-        duration=duration, sampling_frequency=sampling_frequency,
-        frequency_domain_source_model=vector_tensor_sine_gaussian)
-
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+    hrss=1e-22,
+    Q=5.0,
+    frequency=200.0,
+    ra=1.375,
+    dec=-1.2108,
+    geocent_time=1126259642.413,
+    psi=2.659,
+    epsilon=0.2,
+)
+
+waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
+    duration=duration,
+    sampling_frequency=sampling_frequency,
+    frequency_domain_source_model=vector_tensor_sine_gaussian,
+)
+
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
-
-priors = dict()
-for key in ['psi', 'geocent_time', 'hrss', 'Q', 'frequency']:
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 0.5,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator,
+    parameters=injection_parameters,
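+    # raise_error=False allows the injection even if the signal is not
+    # fully contained within the data segment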
+    raise_error=False,
+)
+
+priors = bilby.core.prior.PriorDict()
+for key in ["psi", "geocent_time", "hrss", "Q", "frequency"]:
     priors[key] = injection_parameters[key]
-priors['ra'] = bilby.core.prior.Uniform(0, 2 * np.pi, latex_label='$\\alpha$')
-priors['dec'] = bilby.core.prior.Cosine(latex_label='$\\delta$')
-priors['epsilon'] = bilby.core.prior.Uniform(0, 1, latex_label='$\\epsilon$')
+priors["ra"] = bilby.core.prior.Uniform(0, 2 * np.pi, latex_label="$\\alpha$")
+priors["dec"] = bilby.core.prior.Cosine(latex_label="$\\delta$")
+priors["epsilon"] = bilby.core.prior.Uniform(0, 1, latex_label="$\\epsilon$")
 
 vector_tensor_likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
-# Run sampler.  In this case we're going to use the `dynesty` sampler
+# Run sampler.  In this case we're going to use the `nestle` sampler
 vector_tensor_result = bilby.core.sampler.run_sampler(
-    likelihood=vector_tensor_likelihood, priors=priors, sampler='nestle',
-    npoints=1000, injection_parameters=injection_parameters,
-    outdir=outdir, label='vector_tensor')
+    likelihood=vector_tensor_likelihood,
+    priors=priors,
+    sampler="nestle",
+    nlive=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label="vector_tensor",
+)
 
 vector_tensor_result.plot_corner()
 
 tensor_likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
-priors['epsilon'] = 0
+priors["epsilon"] = 0
 
-# Run sampler.  In this case we're going to use the `dynesty` sampler
+# Run sampler.  In this case we're going to use the `nestle` sampler
 tensor_result = bilby.core.sampler.run_sampler(
-    likelihood=tensor_likelihood, priors=priors, sampler='nestle', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label='tensor')
+    likelihood=tensor_likelihood,
+    priors=priors,
+    sampler="nestle",
+    nlive=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label="tensor",
+)
 
 # make some plots of the outputs
 tensor_result.plot_corner()
 
 bilby.result.plot_multiple(
-    [tensor_result, vector_tensor_result], labels=['Tensor', 'Vector + Tensor'],
-    parameters=dict(ra=1.375, dec=-1.2108), evidences=True)
+    [tensor_result, vector_tensor_result],
+    labels=["Tensor", "Vector + Tensor"],
+    parameters=dict(ra=1.375, dec=-1.2108),
+    evidences=True,
+)
diff --git a/examples/gw_examples/injection_examples/plot_skymap.py b/examples/gw_examples/injection_examples/plot_skymap.py
index 4c7ecb9a3e3abbaa4bab6099330b1332d5846909..7d0dbb98eca15dab0f19ea5e83506614e527182c 100644
--- a/examples/gw_examples/injection_examples/plot_skymap.py
+++ b/examples/gw_examples/injection_examples/plot_skymap.py
@@ -5,47 +5,85 @@ skymap
 """
 import bilby
 
-duration = 4.
-sampling_frequency = 2048.
-outdir = 'outdir'
-label = 'plot_skymap'
+duration = 4
+sampling_frequency = 1024
+outdir = "outdir"
+label = "plot_skymap"
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=4000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-0.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=4000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-0.2108,
+)
 
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(waveform_approximant="IMRPhenomXP", reference_frequency=50.0)
 
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    parameters=injection_parameters, waveform_arguments=waveform_arguments)
+    parameters=injection_parameters,
+    waveform_arguments=waveform_arguments,
+)
 
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 priors = bilby.gw.prior.BBHPriorDict()
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'psi',
-            'mass_1', 'mass_2', 'phase', 'geocent_time', 'luminosity_distance',
-            'theta_jn']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "phi_12",
+    "phi_jl",
+    "psi",
+    "mass_1",
+    "mass_2",
+    "phase",
+    "geocent_time",
+    "theta_jn",
+]:
     priors[key] = injection_parameters[key]
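+# remove the default chirp-mass and mass-ratio priors as the component
+# masses are fixed above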
+del priors["chirp_mass"], priors["mass_ratio"]
 
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator,
-    time_marginalization=True, phase_marginalization=True,
-    distance_marginalization=False, priors=priors)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label, result_class=bilby.gw.result.CBCResult)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="nestle",
+    npoints=250,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+    result_class=bilby.gw.result.CBCResult,
+)
 
 # make some plots of the outputs
 result.plot_corner()
 
 # will require installation of ligo.skymap (pip install ligo.skymap)
-result.plot_skymap(maxpts=5000)
+# the skymap generation code is fairly slow when using many points, so we
+# limit ourselves to 500 points in the fit
+result.plot_skymap(maxpts=500)
diff --git a/examples/gw_examples/injection_examples/plot_time_domain_data.py b/examples/gw_examples/injection_examples/plot_time_domain_data.py
index 4e572e8a53a330a0be47a9bfc5f2b50b6879cd0a..707c0f18c101a3d7ac59e3c143a699a7ed1c88db 100644
--- a/examples/gw_examples/injection_examples/plot_time_domain_data.py
+++ b/examples/gw_examples/injection_examples/plot_time_domain_data.py
@@ -1,37 +1,64 @@
 #!/usr/bin/env python
 """
+This example demonstrates how to simulate some data, inject a signal, and
+plot the data.
 """
 import numpy as np
-import bilby
+from bilby.gw.detector import get_empty_interferometer
+from bilby.gw.source import lal_binary_black_hole
+from bilby.gw.waveform_generator import WaveformGenerator
 
 np.random.seed(1)
 
 duration = 4
-sampling_frequency = 2048.
+sampling_frequency = 2048.0
 
-outdir = 'outdir'
-label = 'example'
+outdir = "outdir"
+label = "example"
 
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=1000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=1000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50.)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomTPHM", reference_frequency=50.0
+)
 
-waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
-    frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
-    parameters=injection_parameters, waveform_arguments=waveform_arguments)
+waveform_generator = WaveformGenerator(
+    duration=duration,
+    sampling_frequency=sampling_frequency,
+    frequency_domain_source_model=lal_binary_black_hole,
+    parameters=injection_parameters,
+    waveform_arguments=waveform_arguments,
+)
 hf_signal = waveform_generator.frequency_domain_strain(injection_parameters)
 
-H1 = bilby.gw.detector.get_interferometer_with_fake_noise_and_injection(
-    'H1', injection_polarizations=hf_signal,
-    injection_parameters=injection_parameters, duration=duration,
-    sampling_frequency=sampling_frequency, outdir=outdir)
+ifo = get_empty_interferometer("H1")
+ifo.set_strain_data_from_power_spectral_density(
+    duration=duration, sampling_frequency=sampling_frequency
+)
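+# inject the pre-computed frequency-domain signal into the simulated noise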
+ifo.inject_signal(injection_polarizations=hf_signal, parameters=injection_parameters)
 
-t0 = injection_parameters['geocent_time']
-H1.plot_time_domain_data(outdir=outdir, label=label, notches=[50],
-                         bandpass_frequencies=(50, 200), start_end=(-0.5, 0.5),
-                         t0=t0)
+t0 = injection_parameters["geocent_time"]
+ifo.plot_time_domain_data(
+    outdir=outdir,
+    label=label,
+    notches=[50],
+    bandpass_frequencies=(50, 200),
+    start_end=(-0.5, 0.5),
+    t0=t0,
+)
diff --git a/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py b/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py
index 699b56c102d6b4fb462ce1bd2ba16d1f73787899..6d2ee4fc2f1e0bf83391b41f5087ed016b80bfdc 100644
--- a/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py
+++ b/examples/gw_examples/injection_examples/reproduce_mpa1_eos.py
@@ -9,19 +9,25 @@ from bilby.gw import eos
 
 # First, we specify the spectral parameter values for the MPA1 EoS.
 MPA1_gammas = [1.0215, 0.1653, -0.0235, -0.0004]
-MPA1_p0 = 1.51e33                                # Pressure in CGS
-MPA1_e0_c2 = 2.04e14                             # \epsilon_0 / c^2 in CGS
-MPA1_xmax = 6.63                                 # Dimensionless ending pressure
+MPA1_p0 = 1.51e33  # Pressure in CGS
+MPA1_e0_c2 = 2.04e14  # \epsilon_0 / c^2 in CGS
+MPA1_xmax = 6.63  # Dimensionless ending pressure
 
 # Create the spectral decomposition EoS class
-MPA1_spectral = eos.SpectralDecompositionEOS(MPA1_gammas, p0=MPA1_p0, e0=MPA1_e0_c2, xmax=MPA1_xmax, npts=100)
+MPA1_spectral = eos.SpectralDecompositionEOS(
+    MPA1_gammas, p0=MPA1_p0, e0=MPA1_e0_c2, xmax=MPA1_xmax, npts=100
+)
 
 # And create another from tabulated data
-MPA1_tabulated = eos.TabularEOS('MPA1')
+MPA1_tabulated = eos.TabularEOS("MPA1")
 
 # Now let's plot them
 # To do so, we specify a representation and plot ranges.
-MPA1_spectral_plot = MPA1_spectral.plot('pressure-energy_density', xlim=[1e22, 1e36], ylim=[1e9, 1e36])
-MPA1_tabular_plot = MPA1_tabulated.plot('pressure-energy_density', xlim=[1e22, 1e36], ylim=[1e9, 1e36])
-MPA1_spectral_plot.savefig('spectral_mpa1.pdf')
-MPA1_tabular_plot.savefig('tabular_mpa1.pdf')
+MPA1_spectral_plot = MPA1_spectral.plot(
+    "pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36]
+)
+MPA1_tabular_plot = MPA1_tabulated.plot(
+    "pressure-energy_density", xlim=[1e22, 1e36], ylim=[1e9, 1e36]
+)
+MPA1_spectral_plot.savefig("spectral_mpa1.pdf")
+MPA1_tabular_plot.savefig("tabular_mpa1.pdf")
diff --git a/examples/gw_examples/injection_examples/roq_example.py b/examples/gw_examples/injection_examples/roq_example.py
index 82f87d3e418620c5b86f136b2a81b0da7bfec3b7..254e29aad0e15e27075186cdf40ac5ae579b09e1 100644
--- a/examples/gw_examples/injection_examples/roq_example.py
+++ b/examples/gw_examples/injection_examples/roq_example.py
@@ -7,14 +7,16 @@ Gaussian noise.
 This requires files specifying the appropriate basis weights.
 These aren't shipped with Bilby, but are available on LDG clusters and
 from the public repository https://git.ligo.org/lscsoft/ROQ_data.
-"""
 
-import numpy as np
+We also reweight the result using the regular waveform model to check how
+accurate the ROQ result is.
+"""
 
 import bilby
+import numpy as np
 
-outdir = 'outdir'
-label = 'roq'
+outdir = "outdir"
+label = "roq"
 
 # The ROQ bases can be given an overall scaling.
 # Note how this is applied to durations, frequencies and masses.
@@ -33,9 +35,9 @@ freq_nodes_quadratic = np.load("fnodes_quadratic.npy") * scale_factor
 params = np.genfromtxt("params.dat", names=True)
 
 # Get scaled ROQ quantities
-minimum_chirp_mass = params['chirpmassmin'] / scale_factor
-maximum_chirp_mass = params['chirpmassmax'] / scale_factor
-minimum_component_mass = params['compmin'] / scale_factor
+minimum_chirp_mass = params["chirpmassmin"] / scale_factor
+maximum_chirp_mass = params["chirpmassmax"] / scale_factor
+minimum_component_mass = params["compmin"] / scale_factor
 
 np.random.seed(170808)
 
@@ -43,73 +45,144 @@ duration = 4 / scale_factor
 sampling_frequency = 2048 * scale_factor
 
 injection_parameters = dict(
-    mass_1=36.0, mass_2=29.0, a_1=0.4, a_2=0.3, tilt_1=0.0, tilt_2=0.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=1000., theta_jn=0.4, psi=0.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
-
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=20. * scale_factor)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.0,
+    tilt_2=0.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=1000.0,
+    theta_jn=0.4,
+    psi=0.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
+
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomPv2", reference_frequency=20.0 * scale_factor
+)
 
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
     waveform_arguments=waveform_arguments,
-    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
+    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
+)
 
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3 / scale_factor)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2 / scale_factor,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 for ifo in ifos:
     ifo.minimum_frequency = 20 * scale_factor
 
 # make ROQ waveform generator
 search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.binary_black_hole_roq,
     waveform_arguments=dict(
         frequency_nodes_linear=freq_nodes_linear,
         frequency_nodes_quadratic=freq_nodes_quadratic,
-        reference_frequency=20. * scale_factor, waveform_approximant='IMRPhenomPv2'),
-    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters)
+        reference_frequency=20.0 * scale_factor,
+        waveform_approximant="IMRPhenomPv2",
+    ),
+    parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
+)
 
 # Here we add constraints on chirp mass and mass ratio to the prior, these are
 # determined by the domain of validity of the ROQ basis.
 priors = bilby.gw.prior.BBHPriorDict()
-for key in ['a_1', 'a_2', 'tilt_1', 'tilt_2', 'theta_jn', 'phase', 'psi', 'ra',
-            'dec', 'phi_12', 'phi_jl', 'luminosity_distance']:
+for key in [
+    "a_1",
+    "a_2",
+    "tilt_1",
+    "tilt_2",
+    "theta_jn",
+    "phase",
+    "psi",
+    "ra",
+    "dec",
+    "phi_12",
+    "phi_jl",
+    "luminosity_distance",
+]:
     priors[key] = injection_parameters[key]
-for key in ['mass_1', 'mass_2']:
+for key in ["mass_1", "mass_2"]:
     priors[key].minimum = max(priors[key].minimum, minimum_component_mass)
-priors['chirp_mass'] = bilby.core.prior.Uniform(
-    name='chirp_mass', minimum=float(minimum_chirp_mass),
-    maximum=float(maximum_chirp_mass))
-priors['mass_ratio'] = bilby.core.prior.Uniform(0.125, 1, name='mass_ratio')
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    injection_parameters['geocent_time'] - 0.1,
-    injection_parameters['geocent_time'] + 0.1, latex_label='$t_c$', unit='s')
+priors["chirp_mass"] = bilby.core.prior.Uniform(
+    name="chirp_mass",
+    minimum=float(minimum_chirp_mass),
+    maximum=float(maximum_chirp_mass),
+)
+# The ROQ params file typically stores the mass-ratio bound as m1/m2, rather
+# than m2/m1 as in the Bilby convention, hence the inversion below.
+priors["mass_ratio"] = bilby.core.prior.Uniform(
+    1 / params["qmax"], 1, name="mass_ratio"
+)
+priors["geocent_time"] = bilby.core.prior.Uniform(
+    injection_parameters["geocent_time"] - 0.1,
+    injection_parameters["geocent_time"] + 0.1,
+    latex_label="$t_c$",
+    unit="s",
+)
 
 likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=search_waveform_generator,
-    linear_matrix=basis_matrix_linear, quadratic_matrix=basis_matrix_quadratic,
-    priors=priors, roq_params=params, roq_scale_factor=scale_factor)
+    interferometers=ifos,
+    waveform_generator=search_waveform_generator,
+    linear_matrix=basis_matrix_linear,
+    quadratic_matrix=basis_matrix_quadratic,
+    priors=priors,
+    roq_params=params,
+    roq_scale_factor=scale_factor,
+)
 
 # write the weights to file so they can be loaded multiple times
-likelihood.save_weights('weights.npz')
+likelihood.save_weights("weights.npz")
 
 # remove the basis matrices as these are big for longer bases
 del basis_matrix_linear, basis_matrix_quadratic
 
 # load the weights from the file
 likelihood = bilby.gw.likelihood.ROQGravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=search_waveform_generator,
-    weights='weights.npz', priors=priors)
+    interferometers=ifos,
+    waveform_generator=search_waveform_generator,
+    weights="weights.npz",
+    priors=priors,
+)
 
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
-
-# Make a corner plot.
-result.plot_corner()
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    npoints=500,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
+
+# Resample the result using the full waveform model with the FakeSampler.
+# This will give us an idea of how good a job the ROQ does.
+full_likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+    interferometers=ifos, waveform_generator=waveform_generator
+)
+resampled_result = bilby.run_sampler(
+    likelihood=full_likelihood,
+    priors=priors,
+    sampler="fake_sampler",
+    label="roq_resampled",
+    outdir=outdir,
+)
+
+# Make a comparison corner plot with the two likelihoods.
+bilby.core.result.plot_multiple([result, resampled_result], labels=["ROQ", "Regular"])
diff --git a/examples/gw_examples/injection_examples/sine_gaussian_example.py b/examples/gw_examples/injection_examples/sine_gaussian_example.py
index d40cb68e89e37a28409c8a292dad4176117a9c20..d253826ea6129ec56a9619b64d097a147b179504 100644
--- a/examples/gw_examples/injection_examples/sine_gaussian_example.py
+++ b/examples/gw_examples/injection_examples/sine_gaussian_example.py
@@ -8,12 +8,12 @@ import numpy as np
 
 # Set the duration and sampling frequency of the data segment that we're going
 # to inject the signal into
-duration = 4.
-sampling_frequency = 2048.
+duration = 1
+sampling_frequency = 512
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'sine_gaussian'
+outdir = "outdir"
+label = "sine_gaussian"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -23,42 +23,64 @@ np.random.seed(170801)
 # dictionary of parameters that includes all of the different waveform
 # parameters
 injection_parameters = dict(
-    hrss=1e-22, Q=5.0, frequency=200.0, ra=1.375, dec=-1.2108,
-    geocent_time=1126259642.413, psi=2.659)
+    hrss=1e-22,
+    Q=5.0,
+    frequency=200.0,
+    ra=1.375,
+    dec=-1.2108,
+    geocent_time=1126259642.413,
+    psi=2.659,
+)
 
 # Create the waveform_generator using a sine Gaussian source function
 waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
-    frequency_domain_source_model=bilby.gw.source.sinegaussian)
+    duration=duration,
+    sampling_frequency=sampling_frequency,
+    frequency_domain_source_model=bilby.gw.source.sinegaussian,
+)
 
 # Set up interferometers.  In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).  These default to
 # their design sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1', 'V1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1", "V1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 0.5,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator,
+    parameters=injection_parameters,
+    raise_error=False,
+)
 
-# Set up prior, which is a dictionary
-priors = dict()
-for key in ['psi', 'ra', 'dec', 'geocent_time']:
+# Set up the prior. We will fix the "extrinsic" parameters to their true values.
+priors = bilby.core.prior.PriorDict()
+for key in ["psi", "ra", "dec", "geocent_time"]:
     priors[key] = injection_parameters[key]
 
-priors['Q'] = bilby.core.prior.Uniform(2, 50, 'Q')
-priors['frequency'] = bilby.core.prior.Uniform(30, 1000, 'frequency', unit='Hz')
-priors['hrss'] = bilby.core.prior.Uniform(1e-23, 1e-21, 'hrss')
+priors["Q"] = bilby.core.prior.Uniform(2, 50, "Q")
+priors["frequency"] = bilby.core.prior.Uniform(160, 240, "frequency", unit="Hz")
+priors["hrss"] = bilby.core.prior.Uniform(1e-23, 1e-21, "hrss")
 
 # Initialise the likelihood by passing in the interferometer data (IFOs) and
-# the waveoform generator
+# the waveform generator
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator)
+    interferometers=ifos, waveform_generator=waveform_generator
+)
 
 # Run sampler.  In this case we're going to use the `dynesty` sampler
 result = bilby.core.sampler.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=1000,
-    injection_parameters=injection_parameters, outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    nlive=1000,
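+    # reduced walk settings are sufficient for this three-parameter problem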
+    walks=10,
+    nact=5,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+)
 
 # make some plots of the outputs
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py b/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py
index 0e084c051a740537621b2c1e55dfeb88c62708fe..db8222e1e7b86425905225b5483d4b298546b2ed 100644
--- a/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py
+++ b/examples/gw_examples/injection_examples/standard_15d_cbc_tutorial.py
@@ -3,18 +3,20 @@
 Tutorial to demonstrate running parameter estimation on a full 15 parameter
-space for an injected cbc signal. This is the standard injection analysis script
+space for an injected CBC signal. This is the standard injection analysis script
 one can modify for the study of injected CBC events.
+
+This will take many hours to run.
 """
-import numpy as np
 import bilby
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment that we're
 # going to inject the signal into
-duration = 4.
-sampling_frequency = 2048.
+duration = 4.0
+sampling_frequency = 1024.0
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'full_15_parameters'
+outdir = "outdir"
+label = "full_15_parameters"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -25,75 +27,120 @@ np.random.seed(88170235)
 # parameters, including masses of the two black holes (mass_1, mass_2),
 # spins of both black holes (a, tilt, phi), etc.
 injection_parameters = dict(
-    mass_1=36., mass_2=29., a_1=0.4, a_2=0.3, tilt_1=0.5, tilt_2=1.0,
-    phi_12=1.7, phi_jl=0.3, luminosity_distance=2000., theta_jn=0.4, psi=2.659,
-    phase=1.3, geocent_time=1126259642.413, ra=1.375, dec=-1.2108)
+    mass_1=36.0,
+    mass_2=29.0,
+    a_1=0.4,
+    a_2=0.3,
+    tilt_1=0.5,
+    tilt_2=1.0,
+    phi_12=1.7,
+    phi_jl=0.3,
+    luminosity_distance=2000.0,
+    theta_jn=0.4,
+    psi=2.659,
+    phase=1.3,
+    geocent_time=1126259642.413,
+    ra=1.375,
+    dec=-1.2108,
+)
 
 # Fixed arguments passed into the source model
-waveform_arguments = dict(waveform_approximant='IMRPhenomPv2',
-                          reference_frequency=50., minimum_frequency=20.)
+waveform_arguments = dict(
+    waveform_approximant="IMRPhenomXPHM",
+    reference_frequency=50.0,
+    minimum_frequency=20.0,
+)
 
 # Create the waveform_generator using a LAL BinaryBlackHole source function
 # the generator will convert all the parameters
 waveform_generator = bilby.gw.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
     parameter_conversion=bilby.gw.conversion.convert_to_lal_binary_black_hole_parameters,
-    waveform_arguments=waveform_arguments)
+    waveform_arguments=waveform_arguments,
+)
 
 # Set up interferometers.  In this case we'll use two interferometers
-# (LIGO-Hanford (H1), LIGO-Livingston (L1). These default to their design
+# (LIGO-Hanford (H1) and LIGO-Livingston (L1)). These default to their design
 # sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 2,
+)
 
+ifos.inject_signal(
+    waveform_generator=waveform_generator, parameters=injection_parameters
+)
 
 # For this analysis, we implement the standard BBH priors defined, except for
 # the definition of the time prior, which is defined as uniform about the
 # injected value.
-# Furthermore, we decide to sample in chirp mass and mass ratio, due to the
-# preferred shape for the associated posterior distributions.
+# We change the mass boundaries to be more targeted for the source we
+# injected.
+# We define priors on the arrival time at the Hanford interferometer and on
+# two parameters (zenith, azimuth) defining the sky position with respect to
+# the two interferometers.
 priors = bilby.gw.prior.BBHPriorDict()
-priors.pop('mass_1')
-priors.pop('mass_2')
 
-priors['chirp_mass'] = bilby.prior.Uniform(
-    name='chirp_mass', latex_label='$M$', minimum=10.0, maximum=100.0,
-    unit='$M_{\\odot}$')
-
-priors['mass_ratio'] = bilby.prior.Uniform(
-    name='mass_ratio', latex_label='$q$', minimum=0.5, maximum=1.0)
-
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    minimum=injection_parameters['geocent_time'] - 0.1,
-    maximum=injection_parameters['geocent_time'] + 0.1,
-    name='geocent_time', latex_label='$t_c$', unit='$s$')
+time_delay = ifos[0].time_delay_from_geocenter(
+    injection_parameters["ra"],
+    injection_parameters["dec"],
+    injection_parameters["geocent_time"],
+)
+priors["H1_time"] = bilby.core.prior.Uniform(
+    minimum=injection_parameters["geocent_time"] + time_delay - 0.1,
+    maximum=injection_parameters["geocent_time"] + time_delay + 0.1,
+    name="H1_time",
+    latex_label="$t_H$",
+    unit="$s$",
+)
+del priors["ra"], priors["dec"]
+priors["zenith"] = bilby.core.prior.Sine(latex_label="$\\kappa$")
+priors["azimuth"] = bilby.core.prior.Uniform(
+    minimum=0, maximum=2 * np.pi, latex_label="$\\epsilon$", boundary="periodic"
+)
 
 # Initialise the likelihood by passing in the interferometer data (ifos) and
-# the waveoform generator, as well the priors.
+# the waveform generator, as well as the priors.
-# The explicit time, distance, and phase marginalizations are turned on to
-# improve convergence, and the parameters are recovered by the conversion
-# function.
+# The explicit distance marginalization is turned on to improve
+# convergence, and the posterior is recovered by the conversion function.
 likelihood = bilby.gw.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=waveform_generator, priors=priors,
-    distance_marginalization=True, phase_marginalization=True, time_marginalization=True)
-
-# Run sampler. In this case we're going to use the `cpnest` sampler
-# Note that the maxmcmc parameter is increased so that between each iteration of
-# the nested sampler approach, the walkers will move further using an mcmc
-# approach, searching the full parameter space.
-# The conversion function will determine the distance, phase and coalescence
-# time posteriors in post processing.
+    interferometers=ifos,
+    waveform_generator=waveform_generator,
+    priors=priors,
+    distance_marginalization=True,
+    phase_marginalization=False,
+    time_marginalization=False,
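+    # sample the sky position in the frame defined by the H1-L1 baseline
+    # and the time at the Hanford detector, matching the priors above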
+    reference_frame="H1L1",
+    time_reference="H1",
+)
+
+# Run sampler. In this case we're going to use the `dynesty` sampler
+# Note that the `walks`, `nact`, and `maxmcmc` parameters are specified
+# to ensure sufficient convergence of the analysis.
+# We set `npool=4` to parallelize the analysis over four cores.
+# The conversion function will determine the distance posterior in post-processing.
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='cpnest', npoints=2000,
-    injection_parameters=injection_parameters, outdir=outdir,
-    label=label, maxmcmc=2000,
-    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    nlive=1000,
+    walks=20,
+    nact=50,
+    maxmcmc=2000,
+    npool=4,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label=label,
+    conversion_function=bilby.gw.conversion.generate_all_bbh_parameters,
+    result_class=bilby.gw.result.CBCResult,
+)
+
+# Plot the inferred waveform superposed on the actual data.
+result.plot_waveform_posterior(n_samples=1000)
 
 # Make a corner plot.
 result.plot_corner()
diff --git a/examples/gw_examples/injection_examples/using_gwin.py b/examples/gw_examples/injection_examples/using_gwin.py
deleted file mode 100644
index de68a82886ed8f262d0f5bef5ddf98629be13b8d..0000000000000000000000000000000000000000
--- a/examples/gw_examples/injection_examples/using_gwin.py
+++ /dev/null
@@ -1,92 +0,0 @@
-#!/usr/bin/env python
-"""
-An example of how to use bilby with `gwin` (https://github.com/gwastro/gwin) to
-perform CBC parameter estimation.
-
-To run this example, it is sufficient to use the pip-installable pycbc package,
-but the source installation of gwin. You can install this by cloning the
-repository (https://github.com/gwastro/gwin) and running
-
-$ python setup.py install
-
-A practical difference between gwin and bilby is that while fixed parameters
-are specified via the prior in bilby, in gwin, these are fixed at instantiation
-of the model. So, in the following, we only create priors for the parameters
-to be searched over.
-
-"""
-import numpy as np
-import bilby
-
-import gwin
-from pycbc import psd as pypsd
-from pycbc.waveform.generator import (FDomainDetFrameGenerator,
-                                      FDomainCBCGenerator)
-
-label = 'using_gwin'
-
-# Search priors
-priors = dict()
-priors['distance'] = bilby.core.prior.Uniform(500, 2000, 'distance')
-priors['polarization'] = bilby.core.prior.Uniform(0, np.pi, 'theta_jn')
-
-# Data variables
-seglen = 4
-sample_rate = 2048
-N = seglen * sample_rate / 2 + 1
-fmin = 30.
-
-# Injected signal variables
-injection_parameters = dict(mass1=38.6, mass2=29.3, spin1z=0, spin2z=0,
-                            tc=0, ra=3.1, dec=1.37, polarization=2.76,
-                            distance=1500)
-
-# These lines figure out which parameters are to be searched over
-variable_parameters = priors.keys()
-fixed_parameters = injection_parameters.copy()
-for key in priors:
-    fixed_parameters.pop(key)
-
-# These lines generate the `model` object - see
-# https://gwin.readthedocs.io/en/latest/api/gwin.models.gaussian_noise.html
-generator = FDomainDetFrameGenerator(
-    FDomainCBCGenerator, 0.,
-    variable_args=variable_parameters, detectors=['H1', 'L1'],
-    delta_f=1. / seglen, f_lower=fmin,
-    approximant='IMRPhenomPv2', **fixed_parameters)
-signal = generator.generate(**injection_parameters)
-psd = pypsd.aLIGOZeroDetHighPower(int(N), 1. / seglen, 20.)
-psds = {'H1': psd, 'L1': psd}
-model = gwin.models.GaussianNoise(
-    variable_parameters, signal, generator, fmin, psds=psds)
-model.update(**injection_parameters)
-
-
-# This create a dummy class to convert the model into a bilby.likelihood object
-class GWINLikelihood(bilby.core.likelihood.Likelihood):
-
-    def __init__(self, model):
-        """ A likelihood to wrap around GWIN model objects
-
-        Parameters
-        ----------
-        model: gwin.model.GaussianNoise
-            A gwin model
-
-        """
-        self.model = model
-        self.parameters = {x: None for x in self.model.variable_params}
-
-    def log_likelihood(self):
-        self.model.update(**self.parameters)
-        return self.model.loglikelihood
-
-
-likelihood = GWINLikelihood(model)
-
-
-# Now run the inference
-result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
-    label=label)
-result.plot_corner()
diff --git a/examples/gw_examples/supernova_example/supernova_example.py b/examples/gw_examples/supernova_example/supernova_example.py
index b5de9f40775a91f27dce75a2a3fe549d6ca4a2c8..1d79fc650bcb6d63d5734d8715bccab1dc67ef0a 100644
--- a/examples/gw_examples/supernova_example/supernova_example.py
+++ b/examples/gw_examples/supernova_example/supernova_example.py
@@ -6,18 +6,22 @@ supernova injected signal.  Signal model is made by applying PCA to a set of
 supernova waveforms. The first few PCs are then linearly combined with a scale
 factor. (See https://arxiv.org/pdf/1202.3256.pdf)
 
+For this example we use `PyMultiNest`, which can be installed with
+
+conda install -c conda-forge pymultinest
 """
-import numpy as np
 import bilby
+import numpy as np
 
 # Set the duration and sampling frequency of the data segment that we're going
-# to inject the signal into
-duration = 3.
-sampling_frequency = 4096.
+# to inject the signal into.
+# These are fixed by the resolution in the injection file that we are using.
+duration = 3
+sampling_frequency = 4096
 
 # Specify the output directory and the name of the simulation.
-outdir = 'outdir'
-label = 'supernova'
+outdir = "outdir"
+label = "supernova"
 bilby.core.utils.setup_logger(outdir=outdir, label=label)
 
 # Set up a random seed for result reproducibility.  This is optional!
@@ -27,70 +31,92 @@ np.random.seed(170801)
 # of parameters that includes all of the different waveform parameters. It will
 # read in a signal to inject from a txt file.
 
-injection_parameters = dict(file_path='MuellerL15_example_inj.txt',
-                            luminosity_distance=7.0, ra=4.6499,
-                            dec=-0.5063, geocent_time=1126259642.413,
-                            psi=2.659)
+injection_parameters = dict(
+    luminosity_distance=7.0,
+    ra=4.6499,
+    dec=-0.5063,
+    geocent_time=1126259642.413,
+    psi=2.659,
+)
 
 # Create the waveform_generator using a supernova source function
 waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.supernova,
-    parameters=injection_parameters)
+    parameters=injection_parameters,
+    parameter_conversion=lambda parameters: (parameters, list()),
+    waveform_arguments=dict(file_path="MuellerL15_example_inj.txt"),
+)
 
 # Set up interferometers.  In this case we'll use three interferometers
 # (LIGO-Hanford (H1), LIGO-Livingston (L1), and Virgo (V1)).  These default to
 # their design sensitivity
-ifos = bilby.gw.detector.InterferometerList(['H1', 'L1'])
+ifos = bilby.gw.detector.InterferometerList(["H1", "L1"])
 ifos.set_strain_data_from_power_spectral_densities(
-    sampling_frequency=sampling_frequency, duration=duration,
-    start_time=injection_parameters['geocent_time'] - 3)
-ifos.inject_signal(waveform_generator=waveform_generator,
-                   parameters=injection_parameters)
+    sampling_frequency=sampling_frequency,
+    duration=duration,
+    start_time=injection_parameters["geocent_time"] - 1.5,
+)
+ifos.inject_signal(
+    waveform_generator=waveform_generator,
+    parameters=injection_parameters,
+    raise_error=False,
+)
 
 # Read in the PCs used to create the signal model from a file.
-realPCs = np.loadtxt('SupernovaRealPCs.txt')
-imagPCs = np.loadtxt('SupernovaImagPCs.txt')
+realPCs = np.genfromtxt("SupernovaRealPCs.txt")
+imagPCs = np.genfromtxt("SupernovaImagPCs.txt")
 
 # Now we make another waveform_generator because the signal model is
 # not the same as the injection in this case.
-simulation_parameters = dict(
-    realPCs=realPCs, imagPCs=imagPCs)
+simulation_parameters = dict(realPCs=realPCs, imagPCs=imagPCs)
 
 search_waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
-    duration=duration, sampling_frequency=sampling_frequency,
+    duration=duration,
+    sampling_frequency=sampling_frequency,
     frequency_domain_source_model=bilby.gw.source.supernova_pca_model,
-    waveform_arguments=simulation_parameters)
+    waveform_arguments=simulation_parameters,
+)
 
 # Set up prior
-priors = dict()
-for key in ['psi', 'geocent_time']:
+priors = bilby.core.prior.PriorDict()
+for key in ["psi", "geocent_time"]:
     priors[key] = injection_parameters[key]
-priors['luminosity_distance'] = bilby.core.prior.Uniform(
-    2, 20, 'luminosity_distance', unit='$kpc$')
-priors['pc_coeff1'] = bilby.core.prior.Uniform(-1, 1, 'pc_coeff1')
-priors['pc_coeff2'] = bilby.core.prior.Uniform(-1, 1, 'pc_coeff2')
-priors['pc_coeff3'] = bilby.core.prior.Uniform(-1, 1, 'pc_coeff3')
-priors['pc_coeff4'] = bilby.core.prior.Uniform(-1, 1, 'pc_coeff4')
-priors['pc_coeff5'] = bilby.core.prior.Uniform(-1, 1, 'pc_coeff5')
-priors['ra'] = bilby.core.prior.Uniform(minimum=0, maximum=2 * np.pi,
-                                        name='ra')
-priors['dec'] = bilby.core.prior.Sine(name='dec')
-priors['geocent_time'] = bilby.core.prior.Uniform(
-    injection_parameters['geocent_time'] - 1,
-    injection_parameters['geocent_time'] + 1,
-    'geocent_time', unit='$s$')
+priors["luminosity_distance"] = bilby.core.prior.Uniform(
+    2, 20, "luminosity_distance", unit="$kpc$"
+)
+priors["pc_coeff1"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff1")
+priors["pc_coeff2"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff2")
+priors["pc_coeff3"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff3")
+priors["pc_coeff4"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff4")
+priors["pc_coeff5"] = bilby.core.prior.Uniform(-1, 1, "pc_coeff5")
+priors["ra"] = bilby.core.prior.Uniform(
+    minimum=0, maximum=2 * np.pi, name="ra", boundary="periodic"
+)
+priors["dec"] = bilby.core.prior.Sine(name="dec")
+priors["geocent_time"] = bilby.core.prior.Uniform(
+    injection_parameters["geocent_time"] - 1,
+    injection_parameters["geocent_time"] + 1,
+    "geocent_time",
+    unit="$s$",
+)
 
 # Initialise the likelihood by passing in the interferometer data (IFOs) and
 # the waveform generator
 likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
-    interferometers=ifos, waveform_generator=search_waveform_generator)
+    interferometers=ifos, waveform_generator=search_waveform_generator
+)
 
 # Run sampler.
 result = bilby.run_sampler(
-    likelihood=likelihood, priors=priors, sampler='dynesty', npoints=500,
-    outdir=outdir, label=label)
+    likelihood=likelihood,
+    priors=priors,
+    sampler="pymultinest",
+    npoints=500,
+    outdir=outdir,
+    label=label,
+)
 
 # make some plots of the outputs
 result.plot_corner()
-print(result)
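
The example above expresses the supernova signal as a linear combination of principal components, with the `pc_coeff*` parameters acting as the search coefficients. As a rough standalone illustration of that idea (not bilby's `supernova_pca_model` implementation, and using a fabricated waveform catalogue):

```python
import numpy as np

rng = np.random.default_rng(170801)
# 20 fabricated "waveforms", 256 samples each (stand-in for a real catalogue)
catalogue = rng.standard_normal((20, 256))

# The rows of Vh from an SVD are the principal components of the catalogue
_, _, pcs = np.linalg.svd(catalogue, full_matrices=False)

# Model signal: the first five PCs weighted by search coefficients,
# analogous to pc_coeff1..pc_coeff5 in the example's prior
coeffs = np.array([0.3, -0.1, 0.05, 0.2, -0.4])
model_signal = coeffs @ pcs[:5]
print(model_signal.shape)  # (256,)
```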
diff --git a/examples/tutorials/fitting_with_x_and_y_errors.ipynb b/examples/tutorials/fitting_with_x_and_y_errors.ipynb
index 6cda25778ac0c0ec54c1b0dd43677e1d94cd72de..56d2969053d44698519b5737f86d1c039ccf19ad 100644
--- a/examples/tutorials/fitting_with_x_and_y_errors.ipynb
+++ b/examples/tutorials/fitting_with_x_and_y_errors.ipynb
@@ -6,20 +6,7 @@
    "source": [
     "# Fitting a model to data with both x and y errors with `Bilby`\n",
     "\n",
-    "Usually when we fit a model to data with a Gaussian Likelihood we assume that we know x values exactly. This is almost never the case. Here we show how to fit a model with errors in both x and y.\n",
-    "\n",
-    "Since we are using a very simple model we will use the `nestle` sampler.\n",
-    "This can be installed using\n",
-    "\n",
-    "```console\n",
-    "$ conda install -c conda-forge nestle\n",
-    "```\n",
-    "\n",
-    "or\n",
-    "\n",
-    "```console\n",
-    "$ pip install nestle\n",
-    "```"
+    "Usually when we fit a model to data with a Gaussian Likelihood we assume that we know x values exactly. This is almost never the case. Here we show how to fit a model with errors in both x and y."
    ]
   },
   {
@@ -62,10 +49,10 @@
     "    xtrue = np.linspace(0, 100, points)\n",
     "    ytrue = model(x=xtrue, m=m, c=c)\n",
     "\n",
-    "    xerr = xerr * np.random.randn(points)\n",
-    "    yerr = yerr * np.random.randn(points)\n",
-    "    xobs = xtrue + xerr\n",
-    "    yobs = ytrue + yerr\n",
+    "    xerr_vals = xerr * np.random.randn(points)\n",
+    "    yerr_vals = yerr * np.random.randn(points)\n",
+    "    xobs = xtrue + xerr_vals\n",
+    "    yobs = ytrue + yerr_vals\n",
     "\n",
     "    plt.errorbar(xobs, yobs, xerr=xerr, yerr=yerr, fmt=\"x\")\n",
     "    plt.errorbar(xtrue, ytrue, yerr=yerr, color=\"black\", alpha=0.5)\n",
@@ -108,7 +95,7 @@
     "    m=bilby.core.prior.Uniform(0, 30, \"m\"), c=bilby.core.prior.Uniform(0, 30, \"c\")\n",
     ")\n",
     "\n",
-    "sampler_kwargs = dict(priors=priors, sampler=\"nestle\", nlive=1000, outdir=\"outdir\", verbose=False)"
+    "sampler_kwargs = dict(priors=priors, sampler=\"bilby_mcmc\", nsamples=1000, printdt=5, outdir=\"outdir\", verbose=False, clean=True)"
    ]
   },
   {
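
The notebook change above swaps `nestle` for `bilby_mcmc` but leaves the statistical point unchanged: for a linear model, Gaussian x-errors can be folded into an effective y-variance. A minimal sketch of such a likelihood, assuming the standard trick sigma_eff^2 = yerr^2 + m^2 * xerr^2 (this is an illustration, not the notebook's exact likelihood class):

```python
import numpy as np
import bilby


class LinearXYLikelihood(bilby.Likelihood):
    """Gaussian likelihood for y = m*x + c with errors on both x and y."""

    def __init__(self, xobs, yobs, xerr, yerr):
        super().__init__(parameters=dict(m=None, c=None))
        self.xobs, self.yobs = xobs, yobs
        self.xerr, self.yerr = xerr, yerr

    def log_likelihood(self):
        m, c = self.parameters["m"], self.parameters["c"]
        # fold the x uncertainty into an effective y-variance
        sigma2 = self.yerr ** 2 + m ** 2 * self.xerr ** 2
        residual = self.yobs - (m * self.xobs + c)
        return -0.5 * np.sum(residual ** 2 / sigma2 + np.log(2 * np.pi * sigma2))
```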
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..41b21bf9462e969b24d03bfea6bd0ef33f5e474a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,10 @@
+[build-system]
+requires = [
+    "setuptools>=42",
+    "setuptools_scm[toml]>=3.4.3",
+    "wheel",
+]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools_scm]
+write_to = "bilby/_version.py"
\ No newline at end of file
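
With `setuptools_scm` writing `bilby/_version.py` at build time, the handwritten version bookkeeping removed from `setup.py` below becomes unnecessary. A sketch of how the package could read the generated version at runtime; the `version` attribute name in the generated file is an assumption that should be checked against the setuptools_scm release in use:

```python
try:
    # file generated by setuptools_scm at build time (see write_to above);
    # the attribute name is an assumption, verify for your setuptools_scm
    from bilby._version import version as __version__
except ImportError:
    # fall back to the installed distribution metadata (Python >= 3.8)
    from importlib.metadata import version
    __version__ = version("bilby")
```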
diff --git a/requirements.txt b/requirements.txt
index 70e457fcaca3e0d459bcf524f47300af8f1d9e93..b69a3c7ce373c9df519b9d75a35b3e7260bf43cc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,4 @@
+bilby.cython>=0.3.0
 dynesty<1.1
 emcee
 corner
diff --git a/sampler_requirements.txt b/sampler_requirements.txt
index d60093c8b1b3e2abd5af5397f02f1986655f5e04..64d2c8c50bdf1d44292952d97f780423c3c6c213 100644
--- a/sampler_requirements.txt
+++ b/sampler_requirements.txt
@@ -3,7 +3,7 @@ dynesty
 emcee
 nestle
 ptemcee
-pymc3>=3.6
+pymc>=4.0.0
 pymultinest
 kombine
 ultranest>=3.0.0
diff --git a/setup.py b/setup.py
index 1bb6b864940ecea1ba0704dc000960936050b19f..e8b575d2949fb429fa569af8ee7a575815c2e4d3 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python
 
 from setuptools import setup
-import subprocess
 import sys
 import os
 
@@ -10,44 +9,6 @@ if python_version < (3, 8):
     sys.exit("Python < 3.8 is not supported, aborting setup")
 
 
-def write_version_file(version):
-    """Writes a file with version information to be used at run time
-
-    Parameters
-    ----------
-    version: str
-        A string containing the current version information
-
-    Returns
-    -------
-    version_file: str
-        A path to the version file
-
-    """
-    try:
-        git_log = subprocess.check_output(
-            ["git", "log", "-1", "--pretty=%h %ai"]
-        ).decode("utf-8")
-        git_diff = (
-            subprocess.check_output(["git", "diff", "."])
-            + subprocess.check_output(["git", "diff", "--cached", "."])
-        ).decode("utf-8")
-        if git_diff == "":
-            git_status = "(CLEAN) " + git_log
-        else:
-            git_status = "(UNCLEAN) " + git_log
-    except Exception as e:
-        print("Unable to obtain git version information, exception: {}".format(e))
-        git_status = "release"
-
-    version_file = ".version"
-    if os.path.isfile(version_file) is False:
-        with open("bilby/" + version_file, "w+") as f:
-            f.write("{}: {}".format(version, git_status))
-
-    return version_file
-
-
 def get_long_description():
     """Finds the README and reads in the description"""
     here = os.path.abspath(os.path.dirname(__file__))
@@ -73,8 +34,6 @@ def readfile(filename):
     return filecontents
 
 
-VERSION = '1.1.5'
-version_file = write_version_file(VERSION)
 long_description = get_long_description()
 
 setup(
@@ -86,7 +45,6 @@ setup(
     author="Greg Ashton, Moritz Huebner, Paul Lasky, Colm Talbot",
     author_email="paul.lasky@monash.edu",
     license="MIT",
-    version=VERSION,
     packages=[
         "bilby",
         "bilby.bilby_mcmc",
@@ -107,7 +65,6 @@ setup(
         "bilby.gw": ["prior_files/*"],
         "bilby.gw.detector": ["noise_curves/*.txt", "detectors/*"],
         "bilby.gw.eos": ["eos_tables/*.dat"],
-        "bilby": [version_file],
     },
     python_requires=">=3.8",
     install_requires=get_requirements(),
diff --git a/test/bilby_mcmc/test_sampler.py b/test/bilby_mcmc/test_sampler.py
index aa52967da16cbfbf754279deed3c83139632d1e7..746eb1a9e1150732e93d5c31664751040e7b639c 100644
--- a/test/bilby_mcmc/test_sampler.py
+++ b/test/bilby_mcmc/test_sampler.py
@@ -3,7 +3,7 @@ import shutil
 import unittest
 
 import bilby
-from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables
+from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler
 from bilby.bilby_mcmc.utils import ConvergenceInputs
 from bilby.core.sampler.base_sampler import SamplerError
 import numpy as np
@@ -44,7 +44,12 @@ class TestBilbyMCMCSampler(unittest.TestCase):
         search_parameter_keys = ['m', 'c']
         use_ratio = False
 
-        _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio)
+        bilby.core.sampler.base_sampler._initialize_global_variables(
+            likelihood,
+            priors,
+            search_parameter_keys,
+            use_ratio,
+        )
 
     def tearDown(self):
         if os.path.isdir(self.outdir):
diff --git a/test/check_author_list.py b/test/check_author_list.py
index 98352318a503eb588b566d6b6b368e24caa5f1d9..752f6d2debe715e94ddf6958b544c82a7d730f52 100644
--- a/test/check_author_list.py
+++ b/test/check_author_list.py
@@ -26,7 +26,7 @@ def remove_accents(raw_text):
     raw_text = re.sub(u"[ß]", 'ss', raw_text)
     raw_text = re.sub(u"[ñ]", 'n', raw_text)
 
-    return(raw_text)
+    return raw_text
 
 
 fail_test = False
diff --git a/test/core/prior/conditional_test.py b/test/core/prior/conditional_test.py
index 3f7e8873169572ab0d6a096c6ddc8ad4f67078d2..94a869936fc6cd71cc3647e51a5a372e08ed4544 100644
--- a/test/core/prior/conditional_test.py
+++ b/test/core/prior/conditional_test.py
@@ -1,7 +1,10 @@
+import os
+import shutil
 import unittest
 from unittest import mock
 
 import numpy as np
+import pandas as pd
 
 import bilby
 
@@ -320,6 +323,43 @@ class TestConditionalPriorDict(unittest.TestCase):
             expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
         self.assertListEqual(expected, res)
 
+    def test_rescale_with_joint_prior(self):
+        """
+        Add a joint prior into the conditional prior dictionary and check that
+        the returned list is flat.
+        """
+
+        # set multivariate Gaussian distribution
+        names = ["mvgvar_0", "mvgvar_1"]
+        mu = [[0.79, -0.83]]
+        cov = [[[0.03, 0.], [0., 0.04]]]
+        mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)
+
+        priordict = bilby.core.prior.ConditionalPriorDict(
+            dict(
+                var_3=self.prior_3,
+                var_2=self.prior_2,
+                var_0=self.prior_0,
+                var_1=self.prior_1,
+                mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"),
+                mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"),
+            )
+        )
+
+        ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
+        keys = list(self.test_sample.keys()) + names
+        res = priordict.rescale(keys=keys, theta=ref_variables)
+
+        self.assertIsInstance(res, list)
+        self.assertEqual(np.shape(res), (6,))
+        self.assertListEqual([isinstance(r, float) for r in res], 6 * [True])
+
+        # check conditional values are still as expected
+        expected = [self.test_sample["var_0"]]
+        for ii in range(1, 4):
+            expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
+        self.assertListEqual(expected, res[0:4])
+
     def test_cdf(self):
         """
         Test that the CDF method is the inverse of the rescale method.
@@ -371,5 +411,33 @@ class TestConditionalPriorDict(unittest.TestCase):
         print(res)
 
 
+class TestDirichletPrior(unittest.TestCase):
+
+    def setUp(self):
+        self.priors = bilby.core.prior.DirichletPriorDict(5)
+
+    def tearDown(self):
+        if os.path.isdir("priors"):
+            shutil.rmtree("priors")
+
+    def test_samples_sum_to_less_than_one(self):
+        """
+        Test that the samples sum to less than one as required for the
+        Dirichlet distribution.
+        """
+        samples = pd.DataFrame(self.priors.sample(10000)).values
+        self.assertLess(max(np.sum(samples, axis=1)), 1)
+
+    def test_read_write_file(self):
+        self.priors.to_file(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict(filename="priors/test.prior")
+        self.assertEqual(self.priors, test)
+
+    def test_read_write_json(self):
+        self.priors.to_json(outdir="priors", label="test")
+        test = bilby.core.prior.PriorDict.from_json(filename="priors/test_prior.json")
+        self.assertEqual(self.priors, test)
+
+
 if __name__ == "__main__":
     unittest.main()
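
The new `TestDirichletPrior` checks the defining constraint of the Dirichlet parameterisation: for `n` fractions only `n - 1` coordinates are sampled, and they must sum to less than one so that the last fraction stays positive. A compact standalone version of that check:

```python
import numpy as np
import pandas as pd
import bilby

priors = bilby.core.prior.DirichletPriorDict(5)  # keys dirichlet_0..dirichlet_3
samples = pd.DataFrame(priors.sample(1000)).values
assert samples.shape[1] == 4                 # n - 1 sampled coordinates
assert np.sum(samples, axis=1).max() < 1     # last fraction remains positive
```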
diff --git a/test/core/prior/joint_test.py b/test/core/prior/joint_test.py
index 20b88a69e6e749f6ba8ec9091132103dd2a7ed48..ebadfcfaeb96ed1cd017a212097c8ff65f5c477c 100644
--- a/test/core/prior/joint_test.py
+++ b/test/core/prior/joint_test.py
@@ -39,5 +39,61 @@ MultivariateGaussianDist(
                 )
 
 
+class TestMultivariateGaussianDistParameterScales(unittest.TestCase):
+    def _test_mvg_ln_prob_diff_expected(self, mvg, mus, sigmas, corrcoefs):
+        # the columns of the Cholesky decomposition give the directions along which
+        # the multivariate Gaussian PDF will decrease by exact differences per unit
+        # sigma; test that these are as expected
+        ln_prob_mus = mvg.ln_prob(mus)
+        d = np.linalg.cholesky(corrcoefs)
+        for i in np.ndindex(4, 4, 4):
+            ln_prob_mus_sigmas_d_i = mvg.ln_prob(mus + sigmas * (d @ i))
+            diff_ln_prob = ln_prob_mus - ln_prob_mus_sigmas_d_i
+            diff_ln_prob_expected = 0.5 * np.sum(np.array(i)**2)
+            self.assertTrue(
+                np.allclose(diff_ln_prob, diff_ln_prob_expected)
+            )
+
+    def test_mvg_unit_scales(self):
+        # test using order-unity standard deviations and correlations
+        sigmas = 0.3 * np.ones(3)
+        corrcoefs = np.identity(3)
+        mus = np.array([3, 1, 2])
+        mvg = bilby.core.prior.MultivariateGaussianDist(
+            names=['a', 'b', 'c'],
+            mus=mus,
+            sigmas=sigmas,
+            corrcoefs=corrcoefs,
+        )
+
+        self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs)
+
+    def test_mvg_cw_scales(self):
+        # test using standard deviations and correlations from the
+        # inverse Fisher information matrix for the frequency/spindown
+        # parameters of a continuous wave signal
+        T = 365.25 * 86400
+        rho = 10
+        sigmas = np.array([
+            5 * np.sqrt(3) / (2 * np.pi * T * rho),
+            6 * np.sqrt(5) / (np.pi * T**2 * rho),
+            60 * np.sqrt(7) / (np.pi * T**3 * rho)
+        ])
+        corrcoefs = np.identity(3)
+        corrcoefs[0, 2] = corrcoefs[2, 0] = -np.sqrt(21) / 5
+
+        # test MultivariateGaussianDist() can handle parameters with very different scales:
+        # - f ~ 100, fd ~ 1/T, fdd ~ 1/T^2
+        mus = [123.4, -5.6e-8, 9e-18]
+        mvg = bilby.core.prior.MultivariateGaussianDist(
+            names=["f", "fd", "fdd"],
+            mus=mus,
+            sigmas=sigmas,
+            corrcoefs=corrcoefs,
+        )
+
+        self._test_mvg_ln_prob_diff_expected(mvg, mus, sigmas, corrcoefs)
+
+
 if __name__ == "__main__":
     unittest.main()
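
The helper above relies on a standard Gaussian identity: if C = L L^T is the covariance, stepping from the mean along the k-th column of L lowers the log-density by exactly 1/2 per unit step squared, independent of how different the parameter scales are. A quick independent verification with scipy (a sketch, separate from bilby's `MultivariateGaussianDist`):

```python
import numpy as np
from scipy.stats import multivariate_normal

mus = np.array([3.0, 1.0, 2.0])
cov = np.diag([0.3, 0.3, 0.3]) ** 2   # sigma = 0.3, no correlations
L = np.linalg.cholesky(cov)

dist = multivariate_normal(mean=mus, cov=cov)
for k in range(3):
    # one unit step along the k-th Cholesky column costs exactly 1/2 in ln p
    drop = dist.logpdf(mus) - dist.logpdf(mus + L[:, k])
    assert np.isclose(drop, 0.5)
```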
diff --git a/test/core/prior/prior_test.py b/test/core/prior/prior_test.py
index 9241945433cca240af4a4dc740b537d24aa59c83..3fd77af938a4abf8fad832a669e89f797ea87f5f 100644
--- a/test/core/prior/prior_test.py
+++ b/test/core/prior/prior_test.py
@@ -673,7 +673,7 @@ class TestPriorClasses(unittest.TestCase):
                     "bilby.core.prior.MultivariateGaussianDist",
                 )
             elif isinstance(prior, bilby.gw.prior.HealPixPrior):
-                repr_prior_string = "bilby.gw.prior." + repr(prior)
+                repr_prior_string = repr(prior)
                 repr_prior_string = repr_prior_string.replace(
                     "HealPixMapPriorDist", "bilby.gw.prior.HealPixMapPriorDist"
                 )
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index 30be5e2ba542205d4cdbef10c6f8e9d681bcbbf1..4856a9e7df4922bf94fe26746972bb1864774535 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -1,7 +1,9 @@
 import copy
 import os
+import shutil
 import unittest
 from unittest.mock import MagicMock
+from parameterized import parameterized
 
 import numpy as np
 
@@ -60,6 +62,15 @@ class TestSampler(unittest.TestCase):
     def test_label(self):
         self.assertEqual(self.sampler.label, "label")
 
+    @parameterized.expand(["sampling_seed", "seed", "random_seed"])
+    def test_translate_kwargs(self, key):
+        self.sampler.sampling_seed_key = key
+        for k in self.sampler.sampling_seed_equiv_kwargs:
+            kwargs = {k: 1234}
+            updated_kwargs = self.sampler._translate_kwargs(kwargs)
+            self.assertDictEqual(updated_kwargs, {key: 1234})
+        self.sampler.sampling_seed_key = None
+
     def test_prior_transform_transforms_search_parameter_keys(self):
         self.sampler.prior_transform([0])
         expected_prior = prior.Uniform(0, 1)
@@ -102,22 +113,100 @@ class TestSampler(unittest.TestCase):
         self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None)
 
     def test_bad_value_np_abs_nan(self):
-        self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.abs(np.nan), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_abs_nan(self):
-        self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=abs(np.nan), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_pos_inf(self):
         self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None)
 
     def test_bad_value_neg_inf(self):
-        self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=-np.inf, warning=False, theta=None, label=None
+        )
 
     def test_bad_value_pos_inf_nan_to_num(self):
-        self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.nan_to_num(np.inf), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_neg_inf_nan_to_num(self):
-        self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None
+        )
+
+
+samplers = [
+    "bilby_mcmc",
+    "dynamic_dynesty",
+    "dynesty",
+    "emcee",
+    "kombine",
+    "ptemcee",
+    "zeus",
+]
+
+
+class GenericSamplerTest(unittest.TestCase):
+    def setUp(self):
+        self.likelihood = bilby.core.likelihood.Likelihood(dict())
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
+        )
+
+    def tearDown(self):
+        if os.path.isdir("outdir"):
+            shutil.rmtree("outdir")
+
+    @parameterized.expand(samplers)
+    def test_pool_creates_properly_no_pool(self, sampler_name):
+        sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name](
+            self.likelihood, self.priors
+        )
+        sampler._setup_pool()
+        if sampler_name == "kombine":
+            from kombine import SerialPool
+
+            self.assertIsInstance(sampler.pool, SerialPool)
+        else:
+            self.assertIsNone(sampler.pool)
+
+    @parameterized.expand(samplers)
+    def test_pool_creates_properly_pool(self, sampler_name):
+        sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name](
+            self.likelihood, self.priors, npool=2
+        )
+        sampler._setup_pool()
+        if hasattr(sampler, "setup_sampler"):
+            sampler.setup_sampler()
+        self.assertEqual(sampler.pool._processes, 2)
+        sampler._close_pool()
+
+
+class ReorderLikelihoodsTest(unittest.TestCase):
+    def setUp(self):
+        self.unsorted_ln_likelihoods = np.array([1, 5, 2, 5, 1])
+        self.unsorted_samples = np.array([[0, 1], [1, 1], [1, 0], [0, 0], [0, 1]])
+        self.sorted_samples = np.array([[0, 1], [0, 1], [1, 0], [1, 1], [0, 0]])
+        self.sorted_ln_likelihoods = np.array([1, 1, 2, 5, 5])
+
+    def tearDown(self):
+        pass
+
+    def test_ordering(self):
+        func = bilby.core.sampler.base_sampler.NestedSampler.reorder_loglikelihoods
+        sorted_ln_likelihoods = func(
+            self.unsorted_ln_likelihoods, self.unsorted_samples, self.sorted_samples
+        )
+        self.assertTrue(
+            np.array_equal(sorted_ln_likelihoods, self.sorted_ln_likelihoods)
+        )
 
 
 if __name__ == "__main__":
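
The parameterized `test_translate_kwargs` added above exercises the pattern where a sampler declares one canonical seed keyword and user-supplied aliases are remapped onto it. A simplified stand-in showing the idea (not bilby's actual `_translate_kwargs`):

```python
SEED_EQUIV_KWARGS = ["sampling_seed", "seed", "random_seed"]


def translate_seed_kwargs(kwargs, canonical_key):
    """Map any recognised seed alias onto the sampler's canonical keyword."""
    for equiv in SEED_EQUIV_KWARGS:
        if equiv in kwargs and equiv != canonical_key:
            kwargs[canonical_key] = kwargs.pop(equiv)
    return kwargs


print(translate_seed_kwargs({"random_seed": 1234}, "seed"))  # {'seed': 1234}
```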
diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py
index 86b03fb38e74afb5ce75c06b5fbd91add3d7f49e..7f6ec21a8a5d26b606e8c6f8aa3cae3ede905ca2 100644
--- a/test/core/sampler/nessai_test.py
+++ b/test/core/sampler/nessai_test.py
@@ -20,9 +20,11 @@ class TestNessai(unittest.TestCase):
             use_ratio=False,
             plot=False,
             skip_import_verification=True,
+            sampling_seed=150914,
         )
         self.expected = self.sampler.default_kwargs
         self.expected['output'] = 'outdir/label_nessai/'
+        self.expected['seed'] = 150914
 
     def tearDown(self):
         del self.likelihood
@@ -54,14 +56,7 @@ class TestNessai(unittest.TestCase):
             self.assertDictEqual(expected, self.sampler.kwargs)
 
     def test_translate_kwargs_seed(self):
-        expected = self.expected.copy()
-        expected["seed"] = 150914
-        for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs:
-            new_kwargs = self.sampler.kwargs.copy()
-            del new_kwargs["seed"]
-            new_kwargs[equiv] = 150914
-            self.sampler.kwargs = new_kwargs
-            self.assertDictEqual(expected, self.sampler.kwargs)
+        assert self.expected["seed"] == 150914
 
     def test_npool_max_threads(self):
         expected = self.expected.copy()
diff --git a/test/core/sampler/pymc3_test.py b/test/core/sampler/pymc_test.py
similarity index 89%
rename from test/core/sampler/pymc3_test.py
rename to test/core/sampler/pymc_test.py
index b3bb758d3270602debdda62f12e6279f9c48d534..c904e1fd880d2fc2ad6c4864ad1dac68424917c2 100644
--- a/test/core/sampler/pymc3_test.py
+++ b/test/core/sampler/pymc_test.py
@@ -1,22 +1,16 @@
 import unittest
 from unittest.mock import MagicMock
 
-import pytest
-
 import bilby
 
 
-@pytest.mark.xfail(
-    raises=AttributeError,
-    reason="Dependency issue with pymc3 causes attribute error on import",
-)
-class TestPyMC3(unittest.TestCase):
+class TestPyMC(unittest.TestCase):
     def setUp(self):
         self.likelihood = MagicMock()
         self.priors = bilby.core.prior.PriorDict(
             dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
         )
-        self.sampler = bilby.core.sampler.Pymc3(
+        self.sampler = bilby.core.sampler.Pymc(
             self.likelihood,
             self.priors,
             outdir="outdir",
@@ -37,7 +31,7 @@ class TestPyMC3(unittest.TestCase):
             step=None,
             init="auto",
             n_init=200000,
-            start=None,
+            initvals=None,
             trace=None,
             chain_idx=0,
             chains=2,
@@ -61,7 +55,7 @@ class TestPyMC3(unittest.TestCase):
             step=None,
             init="auto",
             n_init=200000,
-            start=None,
+            initvals=None,
             trace=None,
             chain_idx=0,
             chains=2,
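
The `start` to `initvals` rename above tracks the pymc 4 API, where `pm.sample` accepts `initvals` in place of pymc3's `start`. A minimal usage sketch, assuming pymc>=4 is installed:

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x", mu=0, sigma=1)
    # pymc3: pm.sample(..., start={"x": 0.1}); pymc >= 4: initvals
    trace = pm.sample(draws=50, tune=50, chains=2, initvals={"x": 0.1})
```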
diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py
index dc578cd71932c877f0de8414361781cc86837789..be22c1a1f50b8d304000fcb8d0e4816e57c9c1b9 100644
--- a/test/core/sampler/ultranest_test.py
+++ b/test/core/sampler/ultranest_test.py
@@ -28,7 +28,7 @@ class TestUltranest(unittest.TestCase):
 
     def test_default_kwargs(self):
         expected = dict(
-            resume=True,
+            resume="overwrite",
             show_status=True,
             num_live_points=None,
             wrapped_params=None,
@@ -63,7 +63,7 @@ class TestUltranest(unittest.TestCase):
 
     def test_translate_kwargs(self):
         expected = dict(
-            resume=True,
+            resume="overwrite",
             show_status=True,
             num_live_points=123,
             wrapped_params=None,
diff --git a/test/gw/conversion_test.py b/test/gw/conversion_test.py
index 54dd064f08e956587debb4951dcf08270efaa612..dbe2ae75a5f65a1c9af6eb13048d3ae4d399f4e7 100644
--- a/test/gw/conversion_test.py
+++ b/test/gw/conversion_test.py
@@ -3,6 +3,7 @@ import unittest
 import numpy as np
 import pandas as pd
 
+
 import bilby
 from bilby.gw import conversion
 
@@ -99,6 +100,13 @@ class TestBasicConversions(unittest.TestCase):
         )
         self.assertAlmostEqual(self.total_mass, total_mass)
 
+    def test_chirp_mass_and_mass_ratio_to_component_masses(self):
+        mass_1, mass_2 = \
+            conversion.chirp_mass_and_mass_ratio_to_component_masses(
+                self.chirp_mass, self.mass_ratio)
+        self.assertAlmostEqual(self.mass_1, mass_1)
+        self.assertAlmostEqual(self.mass_2, mass_2)
+
     def test_component_masses_to_chirp_mass(self):
         chirp_mass = conversion.component_masses_to_chirp_mass(self.mass_1, self.mass_2)
         self.assertAlmostEqual(self.chirp_mass, chirp_mass)
@@ -465,6 +473,39 @@ class TestGenerateAllParameters(unittest.TestCase):
             for key in expected:
                 self.assertIn(key, new_parameters)
 
+    def test_generate_bbh_parameters_with_likelihood(self):
+        priors = bilby.gw.prior.BBHPriorDict()
+        priors["geocent_time"] = bilby.core.prior.Uniform(0.4, 0.6)
+        ifos = bilby.gw.detector.InterferometerList(["H1"])
+        ifos.set_strain_data_from_power_spectral_densities(duration=1, sampling_frequency=256)
+        wfg = bilby.gw.waveform_generator.WaveformGenerator(
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole
+        )
+        likelihood = bilby.gw.likelihood.GravitationalWaveTransient(
+            interferometers=ifos,
+            waveform_generator=wfg,
+            priors=priors,
+            phase_marginalization=True,
+            time_marginalization=True,
+            reference_frame="H1L1",
+        )
+        self.parameters["zenith"] = 0.0
+        self.parameters["azimuth"] = 0.0
+        del self.parameters["ra"], self.parameters["dec"]
+        converted = bilby.gw.conversion.generate_all_bbh_parameters(
+            sample=self.parameters, likelihood=likelihood, priors=priors
+        )
+        extra_expected = [
+            "geocent_time",
+            "phase",
+            "H1_optimal_snr",
+            "H1_matched_filter_snr",
+            "ra",
+            "dec",
+        ]
+        for key in extra_expected:
+            self.assertIn(key, converted)
+
 
 class TestDistanceTransformations(unittest.TestCase):
     def setUp(self):
@@ -490,5 +531,130 @@ class TestDistanceTransformations(unittest.TestCase):
         self.assertAlmostEqual(max(abs(dl - self.distances)), 0, 4)
 
 
+class TestGenerateMassParameters(unittest.TestCase):
+    def setUp(self):
+        self.expected_values = {'mass_1': 2.0,
+                                'mass_2': 1.0,
+                                'chirp_mass': 1.2167286837864113,
+                                'total_mass': 3.0,
+                                'mass_1_source': 4.0,
+                                'mass_2_source': 2.0,
+                                'chirp_mass_source': 2.433457367572823,
+                                'total_mass_source': 6,
+                                'symmetric_mass_ratio': 0.2222222222222222,
+                                'mass_ratio': 0.5}
+
+    def helper_generation_from_keys(self, keys, expected_values, source=False):
+        # Explicitly test the helper generate_component_masses
+        local_test_vars = \
+            {key: expected_values[key] for key in keys}
+        local_test_vars_with_component_masses = \
+            conversion.generate_component_masses(local_test_vars, source=source)
+        if source:
+            self.assertTrue("mass_1_source" in local_test_vars_with_component_masses.keys())
+            self.assertTrue("mass_2_source" in local_test_vars_with_component_masses.keys())
+        else:
+            self.assertTrue("mass_1" in local_test_vars_with_component_masses.keys())
+            self.assertTrue("mass_2" in local_test_vars_with_component_masses.keys())
+        for key in local_test_vars_with_component_masses.keys():
+            self.assertAlmostEqual(
+                local_test_vars_with_component_masses[key],
+                self.expected_values[key])
+
+        # Test the function more generally
+        local_all_mass_parameters = \
+            conversion.generate_mass_parameters(local_test_vars, source=source)
+        if source:
+            self.assertEqual(
+                set(local_all_mass_parameters.keys()),
+                set(["mass_1_source",
+                     "mass_2_source",
+                     "chirp_mass_source",
+                     "total_mass_source",
+                     "symmetric_mass_ratio",
+                     "mass_ratio",
+                     ]
+                    )
+            )
+        else:
+            self.assertEqual(
+                set(local_all_mass_parameters.keys()),
+                set(["mass_1",
+                     "mass_2",
+                     "chirp_mass",
+                     "total_mass",
+                     "symmetric_mass_ratio",
+                     "mass_ratio",
+                     ]
+                    )
+            )
+        for key in local_all_mass_parameters.keys():
+            self.assertAlmostEqual(expected_values[key], local_all_mass_parameters[key])
+
+    def test_from_mass_1_and_mass_2(self):
+        self.helper_generation_from_keys(["mass_1", "mass_2"],
+                                         self.expected_values)
+
+    def test_from_mass_1_and_mass_ratio(self):
+        self.helper_generation_from_keys(["mass_1", "mass_ratio"],
+                                         self.expected_values)
+
+    def test_from_mass_2_and_mass_ratio(self):
+        self.helper_generation_from_keys(["mass_2", "mass_ratio"],
+                                         self.expected_values)
+
+    def test_from_mass_2_and_total_mass(self):
+        self.helper_generation_from_keys(["mass_2", "total_mass"],
+                                         self.expected_values)
+
+    def test_from_chirp_mass_and_mass_ratio(self):
+        self.helper_generation_from_keys(["chirp_mass", "mass_ratio"],
+                                         self.expected_values)
+
+    def test_from_chirp_mass_and_symmetric_mass_ratio(self):
+        self.helper_generation_from_keys(["chirp_mass", "symmetric_mass_ratio"],
+                                         self.expected_values)
+
+    def test_from_chirp_mass_and_mass_1(self):
+        self.helper_generation_from_keys(["chirp_mass", "mass_1"],
+                                         self.expected_values)
+
+    def test_from_chirp_mass_and_mass_2(self):
+        self.helper_generation_from_keys(["chirp_mass", "mass_2"],
+                                         self.expected_values)
+
+    def test_from_mass_1_source_and_mass_2_source(self):
+        self.helper_generation_from_keys(["mass_1_source", "mass_2_source"],
+                                         self.expected_values, source=True)
+
+    def test_from_mass_1_source_and_mass_ratio(self):
+        self.helper_generation_from_keys(["mass_1_source", "mass_ratio"],
+                                         self.expected_values, source=True)
+
+    def test_from_mass_2_source_and_mass_ratio(self):
+        self.helper_generation_from_keys(["mass_2_source", "mass_ratio"],
+                                         self.expected_values, source=True)
+
+    def test_from_mass_2_source_and_total_mass(self):
+        self.helper_generation_from_keys(["mass_2_source", "total_mass_source"],
+                                         self.expected_values, source=True)
+
+    def test_from_chirp_mass_source_and_mass_ratio(self):
+        self.helper_generation_from_keys(["chirp_mass_source", "mass_ratio"],
+                                         self.expected_values, source=True)
+
+    def test_from_chirp_mass_source_and_symmetric_mass_ratio(self):
+        self.helper_generation_from_keys(["chirp_mass_source", "symmetric_mass_ratio"],
+                                         self.expected_values, source=True)
+
+    def test_from_chirp_mass_source_and_mass_1_source(self):
+        self.helper_generation_from_keys(["chirp_mass_source", "mass_1_source"],
+                                         self.expected_values, source=True)
+
+    def test_from_chirp_mass_source_and_mass_2_source(self):
+        self.helper_generation_from_keys(["chirp_mass_source", "mass_2_source"],
+                                         self.expected_values, source=True)
+
+
 if __name__ == "__main__":
     unittest.main()
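
The new `chirp_mass_and_mass_ratio_to_component_masses` routine can be cross-checked against the closed-form inversion m1 = Mc (1+q)^(1/5) q^(-3/5), m2 = q m1 (with q = m2/m1 <= 1). A standalone check using the same fiducial values as `TestGenerateMassParameters` above:

```python
import numpy as np
from bilby.gw import conversion

chirp_mass, mass_ratio = 1.2167286837864113, 0.5
m1 = chirp_mass * (1 + mass_ratio) ** 0.2 / mass_ratio ** 0.6
m2 = mass_ratio * m1
m1_b, m2_b = conversion.chirp_mass_and_mass_ratio_to_component_masses(
    chirp_mass, mass_ratio
)
assert np.isclose(m1, m1_b) and np.isclose(m2, m2_b)  # both ~ (2.0, 1.0)
```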
diff --git a/test/gw/detector/geometry_test.py b/test/gw/detector/geometry_test.py
index 3960c196471dfe4fd6856f0ecebd2aafc401f187..358825b237cda9729cecf635c3179c02d38d7b8d 100644
--- a/test/gw/detector/geometry_test.py
+++ b/test/gw/detector/geometry_test.py
@@ -138,81 +138,56 @@ class TestInterferometerGeometry(unittest.TestCase):
         self.geometry.latitude = 0
         self.assertTrue(np.array_equal(self.geometry.y, np.array([1])))
 
-    def test_detector_tensor_without_update(self):
-        _ = self.geometry.detector_tensor
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            expected = np.array(
-                [
-                    [-9.24529394e-06, 1.02425803e-04, 3.24550668e-04],
-                    [1.02425803e-04, 1.37390844e-03, -8.61137566e-03],
-                    [3.24550668e-04, -8.61137566e-03, -1.36466315e-03],
-                ]
-            )
-            self.assertTrue(np.allclose(expected, self.geometry.detector_tensor))
-
     def test_detector_tensor_with_x_azimuth_update(self):
-        _ = self.geometry.detector_tensor
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            self.geometry.xarm_azimuth = 1
-            self.assertEqual(0, self.geometry.detector_tensor)
+        original = self.geometry.detector_tensor
+        self.geometry.xarm_azimuth += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_detector_tensor_with_y_azimuth_update(self):
-        _ = self.geometry.detector_tensor
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            self.geometry.yarm_azimuth = 1
-            self.assertEqual(0, self.geometry.detector_tensor)
+        original = self.geometry.detector_tensor
+        self.geometry.yarm_azimuth += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_detector_tensor_with_x_tilt_update(self):
-        _ = self.geometry.detector_tensor
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            self.geometry.xarm_tilt = 1
-            self.assertEqual(0, self.geometry.detector_tensor)
+        original = self.geometry.detector_tensor
+        self.geometry.xarm_tilt += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_detector_tensor_with_y_tilt_update(self):
-        _ = self.geometry.detector_tensor
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            self.geometry.yarm_tilt = 1
-            self.assertEqual(0, self.geometry.detector_tensor)
+        original = self.geometry.detector_tensor
+        self.geometry.yarm_tilt += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_detector_tensor_with_longitude_update(self):
-        with mock.patch("numpy.einsum") as m:
-            m.return_value = 1
-            self.geometry.longitude = 1
-            self.assertEqual(0, self.geometry.detector_tensor)
+        original = self.geometry.detector_tensor
+        self.geometry.longitude += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_detector_tensor_with_latitude_update(self):
-        with mock.patch("numpy.einsum") as m:
-            _ = self.geometry.detector_tensor
-            m.return_value = 1
-            self.geometry.latitude = 1
-            self.assertEqual(self.geometry.detector_tensor, 0)
+        original = self.geometry.detector_tensor
+        self.geometry.latitude += 1
+        self.assertNotEqual(np.max(abs(self.geometry.detector_tensor - original)), 0)
 
     def test_unit_vector_along_arm_default(self):
         with self.assertRaises(ValueError):
             self.geometry.unit_vector_along_arm("z")
 
     def test_unit_vector_along_arm_x(self):
-        with mock.patch("numpy.array") as m:
-            m.return_value = 1
-            self.geometry.xarm_tilt = 0
-            self.geometry.xarm_azimuth = 0
-            self.geometry.yarm_tilt = 0
-            self.geometry.yarm_azimuth = 90
-            self.assertAlmostEqual(self.geometry.unit_vector_along_arm("x"), 1)
+        self.geometry.longitude = 0
+        self.geometry.latitude = 0
+        self.geometry.xarm_tilt = 0
+        self.geometry.xarm_azimuth = 0
+        arm = self.geometry.unit_vector_along_arm("x")
+        self.assertTrue(np.allclose(arm, np.array([0, 1, 0])))
 
     def test_unit_vector_along_arm_y(self):
-        with mock.patch("numpy.array") as m:
-            m.return_value = 1
-            self.geometry.xarm_tilt = 0
-            self.geometry.xarm_azimuth = 90
-            self.geometry.yarm_tilt = 0
-            self.geometry.yarm_azimuth = 180
-            self.assertAlmostEqual(self.geometry.unit_vector_along_arm("y"), -1)
+        self.geometry.longitude = 0
+        self.geometry.latitude = 0
+        self.geometry.yarm_tilt = 0
+        self.geometry.yarm_azimuth = 90
+        arm = self.geometry.unit_vector_along_arm("y")
+        self.assertTrue(np.allclose(arm, np.array([0, 0, 1])))
 
     def test_repr(self):
         expected = (
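
The rewritten geometry tests drop the `einsum` mocks and instead assert that the detector tensor actually responds to geometry changes. That behaviour follows from the standard definition D = (x x^T - y y^T) / 2 in terms of the unit vectors along the two arms, illustrated here with a toy orthogonal-arm configuration:

```python
import numpy as np

x = np.array([1.0, 0.0, 0.0])  # unit vector along the x arm
y = np.array([0.0, 1.0, 0.0])  # unit vector along the y arm
D = 0.5 * (np.einsum("i,j->ij", x, x) - np.einsum("i,j->ij", y, y))
print(np.diag(D))  # [ 0.5 -0.5  0. ]; rotating either arm changes D
```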
diff --git a/test/gw/detector/interferometer_test.py b/test/gw/detector/interferometer_test.py
index 3eea3c59d4667bd3fad5232bfb866ca5b98b636d..ad324e00726980b555fec731eb61eeaa2418fffe 100644
--- a/test/gw/detector/interferometer_test.py
+++ b/test/gw/detector/interferometer_test.py
@@ -97,21 +97,6 @@ class TestInterferometer(unittest.TestCase):
     def test_max_freq_setting(self):
         self.assertEqual(self.ifo.strain_data.maximum_frequency, self.maximum_frequency)
 
-    def test_antenna_response_default(self):
-        with mock.patch("bilby.gw.utils.get_polarization_tensor") as m:
-            with mock.patch("numpy.einsum") as n:
-                m.return_value = 0
-                n.return_value = 1
-                self.assertEqual(self.ifo.antenna_response(234, 52, 54, 76, "plus"), 1)
-
-    def test_antenna_response_einsum(self):
-        with mock.patch("bilby.gw.utils.get_polarization_tensor") as m:
-            m.return_value = np.ones((3, 3))
-            self.assertAlmostEqual(
-                self.ifo.antenna_response(234, 52, 54, 76, "plus"),
-                self.ifo.detector_tensor.sum(),
-            )
-
     def test_get_detector_response_default_behaviour(self):
         self.ifo.antenna_response = mock.MagicMock(return_value=1)
         self.ifo.time_delay_from_geocenter = mock.MagicMock(return_value=0)
@@ -315,16 +300,6 @@ class TestInterferometer(unittest.TestCase):
         with self.assertRaises(ValueError):
             self.ifo.inject_signal(injection_polarizations=None, parameters=None)
 
-    def test_time_delay_from_geocenter(self):
-        with mock.patch("bilby.gw.utils.time_delay_geocentric") as m:
-            m.return_value = 1
-            self.assertEqual(self.ifo.time_delay_from_geocenter(1, 2, 3), 1)
-
-    def test_vertex_position_geocentric(self):
-        with mock.patch("bilby.gw.utils.get_vertex_position_geocentric") as m:
-            m.return_value = 1
-            self.assertEqual(self.ifo.vertex_position_geocentric(), 1)
-
     def test_optimal_snr_squared(self):
         """
         Merely checks parameters are given in the right order and the frequency
@@ -563,6 +538,25 @@ class TestInterferometerAntennaPatternAgainstLAL(unittest.TestCase):
                 with self.subTest(':'.join((ifo_name, pol))):
                     self.assertAlmostEqual(std[m], 0.0, places=7)
 
+    def test_time_delay_vs_lal(self):
+        delays = np.zeros(self.trial)
+
+        for n, ifo_name in enumerate(self.ifo_names):
+            ifo = self.ifos[n]
+            det = lal.cached_detector_by_prefix[self.lal_prefixes[ifo_name]]
+            for i in range(self.trial):
+                gpstime = np.random.uniform(1205303144, 1405303144)
+                ra = 2. * np.pi * np.random.uniform()
+                dec = np.pi * np.random.uniform() - np.pi / 2.
+                delays[i] = (
+                    lal.TimeDelayFromEarthCenter(det.location, ra, dec, gpstime)
+                    - ifo.time_delay_from_geocenter(ra, dec, gpstime)
+                )
+
+            std = max(abs(delays))
+            with self.subTest(ifo_name):
+                self.assertAlmostEqual(std, 0.0, places=10)
+
 
 if __name__ == "__main__":
     unittest.main()
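
The new `test_time_delay_vs_lal` compares bilby's geocentric time delay with `lal.TimeDelayFromEarthCenter`. The quantity being checked is a light-travel-time projection, delay = (r . n) / c for detector position r and propagation direction n; the sign convention below is an assumption for illustration (LAL fixes it precisely):

```python
import numpy as np

c = 299792458.0                                # speed of light, m/s
r_detector = np.array([4.5e6, 0.0, 4.5e6])     # hypothetical detector position (m)
n_propagation = np.array([0.0, 0.0, -1.0])     # wave travelling in -z
delay = np.dot(r_detector, n_propagation) / c  # arrival relative to geocentre
print(delay)  # ~ -0.015 s: the wavefront reaches this detector first
```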
diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py
index 5937df0a6000db86ac45df7a020b36bee4ca409d..eeef6c65c420f740739577fef883e6912e928f63 100644
--- a/test/gw/likelihood_test.py
+++ b/test/gw/likelihood_test.py
@@ -1,3 +1,4 @@
+import itertools
 import os
 import pytest
 import unittest
@@ -648,6 +649,43 @@ class TestMarginalizations(unittest.TestCase):
             prior=prior,
         )
 
+    @parameterized.expand(
+        itertools.product(["regular", "roq"], *itertools.repeat([True, False], 3)),
+        name_func=lambda func, num, param: (
+            f"{func.__name__}_{num}__{param.args[0]}_" + "_".join([
+                ["D", "P", "T"][ii] for ii, val
+                in enumerate(param.args[1:]) if val
+            ])
+        )
+    )
+    def test_marginalization_reconstruction(self, kind, distance, phase, time):
+        if time and kind == "roq":
+            pytest.skip("Time reconstruction not supported for ROQ likelihood")
+        marginalizations = dict(
+            geocent_time=time,
+            luminosity_distance=distance,
+            phase=phase,
+        )
+        like = self.get_likelihood(
+            kind=kind,
+            distance_marginalization=distance,
+            time_marginalization=time,
+            phase_marginalization=phase,
+        )
+        params = self.parameters.copy()
+        reference_values = dict(
+            luminosity_distance=self.priors["luminosity_distance"].rescale(0.5),
+            geocent_time=self.interferometers.start_time,
+            phase=0.0,
+        )
+        for key in marginalizations:
+            if marginalizations[key]:
+                params[key] = reference_values[key]
+        like.parameters.update(params)
+        output = like.generate_posterior_sample_from_marginalized_likelihood()
+        for key in marginalizations:
+            self.assertFalse(marginalizations[key] and reference_values[key] == output[key])
+
 
 class TestROQLikelihood(unittest.TestCase):
     def setUp(self):
@@ -1010,8 +1048,8 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
 
     """
 
-    _path_to_basis = "/roq_basis/basis.hdf5"
-    _path_to_basis_mb = "/roq_basis/basis_multiband.hdf5"
+    _path_to_basis = "/roq_basis/basis_addcal.hdf5"
+    _path_to_basis_mb = "/roq_basis/basis_multiband_addcal.hdf5"
 
     def setUp(self):
         self.minimum_frequency = 20
@@ -1114,11 +1152,12 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
         product(
             [_path_to_basis, _path_to_basis_mb],
             [_path_to_basis, _path_to_basis_mb],
-            [(8, 9), (8, 10.5), (8, 11.5), (8, 12.5), (8, 14)],
-            [1, 2]
+            [(8, 9), (8, 14)],
+            [1, 2],
+            [False, True]
         )
     )
-    def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor):
+    def test_likelihood_accuracy(self, basis_linear, basis_quadratic, mc_range, roq_scale_factor, add_cal_errors):
         "Compare with log likelihood ratios computed by the non-ROQ likelihood"
         self.minimum_frequency *= roq_scale_factor
         self.sampling_frequency *= roq_scale_factor
@@ -1139,6 +1178,25 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
             duration=self.duration,
             start_time=self.injection_parameters["geocent_time"] - self.duration + 1
         )
+
+        if add_cal_errors:
+            spline_calibration_nodes = 10
+            np.random.seed(170817)
+            for ifo in interferometers:
+                prefix = f"recalib_{ifo.name}_"
+                ifo.calibration_model = bilby.gw.calibration.CubicSpline(
+                    prefix=prefix,
+                    minimum_frequency=ifo.minimum_frequency,
+                    maximum_frequency=ifo.maximum_frequency,
+                    n_points=spline_calibration_nodes
+                )
+                for i in range(spline_calibration_nodes):
+                    # 5% in amplitude, 5deg in phase
+                    self.injection_parameters[f"{prefix}amplitude_{i}"] = \
+                        np.random.normal(loc=0, scale=0.05)
+                    self.injection_parameters[f"{prefix}phase_{i}"] = \
+                        np.random.normal(loc=0, scale=5 * np.pi / 180)
+
         waveform_generator = bilby.gw.WaveformGenerator(
             duration=self.duration,
             sampling_frequency=self.sampling_frequency,
@@ -1176,9 +1234,9 @@ class TestROQLikelihoodHDF5(unittest.TestCase):
         # The maximum error of log likelihood ratio. It is set to be larger for roq_scale_factor=1 as the injected SNR
         # is higher.
         if roq_scale_factor == 1:
-            max_llr_error = 1e-1
+            max_llr_error = 5e-1
         elif roq_scale_factor == 2:
-            max_llr_error = 1e-2
+            max_llr_error = 5e-2
         else:
             raise
         for mc in np.linspace(self.priors["chirp_mass"].minimum, self.priors["chirp_mass"].maximum, 11):
@@ -1200,8 +1258,8 @@ class TestCreateROQLikelihood(unittest.TestCase):
 
     """
 
-    _path_to_basis = "/roq_basis/basis.hdf5"
-    _path_to_basis_mb = "/roq_basis/basis_multiband.hdf5"
+    _path_to_basis = "/roq_basis/basis_addcal.hdf5"
+    _path_to_basis_mb = "/roq_basis/basis_multiband_addcal.hdf5"
 
     @parameterized.expand(product([_path_to_basis, _path_to_basis_mb], [_path_to_basis, _path_to_basis_mb]))
     def test_from_hdf5(self, basis_linear, basis_quadratic):
@@ -1487,9 +1545,9 @@ class TestInOutROQWeights(unittest.TestCase):
         )
 
         if multiband:
-            path_to_basis = "/roq_basis/basis_multiband.hdf5"
+            path_to_basis = "/roq_basis/basis_multiband_addcal.hdf5"
         else:
-            path_to_basis = "/roq_basis/basis.hdf5"
+            path_to_basis = "/roq_basis/basis_addcal.hdf5"
         return bilby.gw.likelihood.ROQGravitationalWaveTransient(
             interferometers=interferometers,
             priors=priors,
@@ -1542,6 +1600,24 @@ class TestMBLikelihood(unittest.TestCase):
         for ifo in ifos:
             ifo.minimum_frequency = fmin
 
+        spline_calibration_nodes = 10
+        self.calibration_parameters = {}
+        for ifo in ifos:
+            ifo.calibration_model = bilby.gw.calibration.CubicSpline(
+                prefix=f"recalib_{ifo.name}_",
+                minimum_frequency=ifo.minimum_frequency,
+                maximum_frequency=ifo.maximum_frequency,
+                n_points=spline_calibration_nodes
+            )
+            for i in range(spline_calibration_nodes):
+                self.test_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = 0
+                self.test_parameters[f"recalib_{ifo.name}_phase_{i}"] = 0
+                # Calibration errors of 5% in amplitude and 5 degrees in phase
+                self.calibration_parameters[f"recalib_{ifo.name}_amplitude_{i}"] = \
+                    np.random.normal(loc=0, scale=0.05)
+                self.calibration_parameters[f"recalib_{ifo.name}_phase_{i}"] = \
+                    np.random.normal(loc=0, scale=5 * np.pi / 180)
+
         priors = bilby.gw.prior.BBHPriorDict()
         priors.pop("mass_1")
         priors.pop("mass_2")
@@ -1625,32 +1701,44 @@ class TestMBLikelihood(unittest.TestCase):
             self.mb_more_accurate
         )
 
-    def test_matches_non_mb(self):
+    @parameterized.expand([(False, ), (True, )])
+    def test_matches_non_mb(self, add_cal_errors):
         self.non_mb_22.parameters.update(self.test_parameters)
         self.mb_22.parameters.update(self.test_parameters)
+        if add_cal_errors:
+            self.non_mb_22.parameters.update(self.calibration_parameters)
+            self.mb_22.parameters.update(self.calibration_parameters)
         self.assertLess(
             abs(self.non_mb_22.log_likelihood_ratio() - self.mb_22.log_likelihood_ratio()),
-            1e-2
+            1.5e-2
         )
 
-    def test_ifft_fft(self):
+    @parameterized.expand([(False, ), (True, )])
+    def test_ifft_fft(self, add_cal_errors):
         """
         Check if multi-banding likelihood with (h, h) computed with the
         IFFT-FFT algorithm matches the original likelihood.
         """
         self.non_mb_22.parameters.update(self.test_parameters)
         self.mb_ifftfft_22.parameters.update(self.test_parameters)
+        if add_cal_errors:
+            self.non_mb_22.parameters.update(self.calibration_parameters)
+            self.mb_ifftfft_22.parameters.update(self.calibration_parameters)
         self.assertLess(
             abs(self.non_mb_22.log_likelihood_ratio() - self.mb_ifftfft_22.log_likelihood_ratio()),
-            5e-3
+            6e-3
         )
 
-    def test_homs(self):
+    @parameterized.expand([(False, ), (True, )])
+    def test_homs(self, add_cal_errors):
         """
         Check if multi-banding likelihood matches the original likelihood for higher-order moments.
         """
         self.non_mb_homs.parameters.update(self.test_parameters)
         self.mb_homs.parameters.update(self.test_parameters)
+        if add_cal_errors:
+            self.non_mb_homs.parameters.update(self.calibration_parameters)
+            self.mb_homs.parameters.update(self.calibration_parameters)
         self.assertLess(
             abs(self.non_mb_homs.log_likelihood_ratio() - self.mb_homs.log_likelihood_ratio()),
             1e-3
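
The calibration parameters injected throughout these tests feed bilby's `CubicSpline` model of frequency-dependent amplitude and phase errors; schematically the template is rescaled as h(f) -> h(f) (1 + dA(f)) exp(i dphi(f)). A toy version with constant offsets (the spline interpolation across nodes is what bilby actually does):

```python
import numpy as np

frequencies = np.linspace(20, 512, 8)
h = np.ones_like(frequencies, dtype=complex)  # placeholder strain
dA = 0.05                                     # 5% amplitude error
dphi = 5 * np.pi / 180                        # 5 degrees of phase error
h_calibrated = h * (1 + dA) * np.exp(1j * dphi)
print(abs(h_calibrated[0]), np.angle(h_calibrated[0]))  # 1.05, ~0.0873
```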
diff --git a/test/gw/utils_test.py b/test/gw/utils_test.py
index c5f3af46bb84558ab5be67f832de5657f027720d..7946a99c83b99958160636f482f4593103f6c7fe 100644
--- a/test/gw/utils_test.py
+++ b/test/gw/utils_test.py
@@ -36,34 +36,6 @@ class TestGWUtils(unittest.TestCase):
         psd = gwutils.psd_from_freq_series(freq_data, df)
         self.assertTrue(np.all(psd == (freq_data * 2 * df ** 0.5) ** 2))
 
-    def test_time_delay_from_geocenter(self):
-        """
-        The difference in the two detector case is due to rounding error.
-        Different hardware gives different numbers in the last decimal place.
-        """
-        det1 = np.array([0.1, 0.2, 0.3])
-        det2 = np.array([0.1, 0.2, 0.5])
-        ra = 0.5
-        dec = 0.2
-        time = 10
-        self.assertEqual(gwutils.time_delay_geocentric(det1, det1, ra, dec, time), 0)
-        self.assertAlmostEqual(
-            gwutils.time_delay_geocentric(det1, det2, ra, dec, time),
-            1.3253791114055397e-10,
-            14,
-        )
-
-    def test_get_polarization_tensor(self):
-        ra = 1
-        dec = 2.0
-        time = 10
-        psi = 0.1
-        for mode in ["plus", "cross", "breathing", "longitudinal", "x", "y"]:
-            p = gwutils.get_polarization_tensor(ra, dec, time, psi, mode)
-            self.assertEqual(p.shape, (3, 3))
-        with self.assertRaises(ValueError):
-            gwutils.get_polarization_tensor(ra, dec, time, psi, "not-a-mode")
-
     def test_inner_product(self):
         aa = np.array([1, 2, 3])
         bb = np.array([5, 6, 7])
diff --git a/test/gw/waveform_generator_test.py b/test/gw/waveform_generator_test.py
index f2564c6d539fcc8d0dd0f100d87ac47f345dd29b..c4bd5729f32257cecea81d5731c2e6238e16b88c 100644
--- a/test/gw/waveform_generator_test.py
+++ b/test/gw/waveform_generator_test.py
@@ -1,6 +1,8 @@
 import unittest
 from unittest import mock
+
 import bilby
+import lalsimulation
 import numpy as np
 
 
@@ -159,6 +161,69 @@ class TestWaveformArgumentsSetting(unittest.TestCase):
         )
 
 
+class TestLALCBCWaveformArgumentsSetting(unittest.TestCase):
+    def setUp(self):
+        self.kwargs = dict(
+            duration=4,
+            frequency_domain_source_model=bilby.gw.source.lal_binary_black_hole,
+            sampling_frequency=2048,
+        )
+
+    def tearDown(self):
+        del self.kwargs
+
+    def test_spin_reference_enumeration(self):
+        """
+        Verify that the value of the reference enumerator hasn't changed by comparing
+        against a known approximant.
+        """
+        self.assertEqual(
+            lalsimulation.SimInspiralGetSpinFreqFromApproximant(lalsimulation.SEOBNRv3),
+            bilby.gw.waveform_generator.LALCBCWaveformGenerator.LAL_SIM_INSPIRAL_SPINS_FLOW,
+        )
+
+    def test_create_waveform_generator_non_precessing(self):
+        self.kwargs["waveform_arguments"] = dict(
+            minimum_frequency=20.0,
+            reference_frequency=50.0,
+            waveform_approximant="TaylorF2",
+        )
+        wfg = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
+        self.assertDictEqual(
+            wfg.waveform_arguments,
+            dict(
+                minimum_frequency=20.0,
+                reference_frequency=50.0,
+                waveform_approximant="TaylorF2",
+            ),
+        )
+
+    def test_create_waveform_generator_eob_succeeds(self):
+        self.kwargs["waveform_arguments"] = dict(
+            minimum_frequency=20.0,
+            reference_frequency=20.0,
+            waveform_approximant="SEOBNRv3",
+        )
+        wfg = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
+        self.assertDictEqual(
+            wfg.waveform_arguments,
+            dict(
+                minimum_frequency=20.0,
+                reference_frequency=20.0,
+                waveform_approximant="SEOBNRv3",
+            ),
+        )
+
+    def test_create_waveform_generator_eob_fails(self):
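+        # SEOBNRv3 defines its spins at f_low (see the enum test above), so a
+        # reference frequency differing from the minimum frequency is rejected.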
+        self.kwargs["waveform_arguments"] = dict(
+            minimum_frequency=20.0,
+            reference_frequency=50.0,
+            waveform_approximant="SEOBNRv3",
+        )
+        with self.assertRaises(ValueError):
+            _ = bilby.gw.waveform_generator.LALCBCWaveformGenerator(**self.kwargs)
+
+
 class TestSetters(unittest.TestCase):
     def setUp(self):
         self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator(
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index 2bf1d355e03cf48a91a3f9ff29b97ec47f10264e..17307e7d7b1897cd5d7580ea74a8029ac755c758 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -1,34 +1,100 @@
+import multiprocessing
+import os
+import sys
+import threading
+import time
+from signal import SIGINT
+
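+# The start method must be set before any pool is created, hence before the
+# remaining imports; "fork" lets workers inherit module-level test state.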
+multiprocessing.set_start_method("fork")  # noqa
+
 import unittest
 import pytest
+from parameterized import parameterized
 import shutil
 
 import bilby
 import numpy as np
 
 
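+# Minimal per-sampler settings, chosen so that each run finishes quickly in CI.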
+_sampler_kwargs = dict(
+    bilby_mcmc=dict(nsamples=200, printdt=1),
+    cpnest=dict(nlive=100),
+    dnest4=dict(
+        max_num_levels=2,
+        num_steps=10,
+        new_level_interval=10,
+        num_per_step=10,
+        thread_steps=1,
+        num_particles=50,
+        max_pool=1,
+    ),
+    dynesty=dict(nlive=100),
+    dynamic_dynesty=dict(
+        nlive_init=100,
+        nlive_batch=100,
+        dlogz_init=1.0,
+        maxbatch=0,
+        maxcall=100,
+        bound="single",
+    ),
+    emcee=dict(iterations=1000, nwalkers=10),
+    kombine=dict(iterations=200, nwalkers=10, autoburnin=False),
+    nessai=dict(
+        nlive=100,
+        poolsize=1000,
+        max_iteration=1000,
+        max_threads=3,
+    ),
+    nestle=dict(nlive=100),
+    ptemcee=dict(
+        nsamples=100,
+        nwalkers=50,
+        burn_in_act=1,
+        ntemps=1,
+        frac_threshold=0.5,
+    ),
+    PTMCMCSampler=dict(Niter=101, burn=2, isave=100),
+    pymc=dict(draws=50, tune=50, n_init=250),
+    pymultinest=dict(nlive=100),
+    pypolychord=dict(nlive=100),
+    ultranest=dict(nlive=100, temporary_directory=False),
+)
+
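+# Samplers whose importable package differs from the bilby sampler name.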
+sampler_imports = dict(
+    bilby_mcmc="bilby",
+    dynamic_dynesty="dynesty",
+)
+
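+# Samplers that do not support bilby's multiprocessing pool (npool > 1).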
+no_pool_test = [
+    "dnest4", "pymultinest", "nestle", "ptmcmcsampler",
+    "pypolychord", "ultranest", "pymc",
+]
+
+
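+# Deliberately slow model for the interrupt tests, so SIGINT arrives mid-run.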
+def slow_func(x, m, c):
+    time.sleep(0.01)
+    return m * x + c
+
+
+def model(x, m, c):
+    return m * x + c
+
+
 class TestRunningSamplers(unittest.TestCase):
     def setUp(self):
         np.random.seed(42)
         bilby.core.utils.command_line_args.bilby_test_mode = False
         self.x = np.linspace(0, 1, 11)
-        self.model = lambda x, m, c: m * x + c
         self.injection_parameters = dict(m=0.5, c=0.2)
         self.sigma = 0.1
-        self.y = self.model(self.x, **self.injection_parameters) + np.random.normal(
+        self.y = model(self.x, **self.injection_parameters) + np.random.normal(
             0, self.sigma, len(self.x)
         )
         self.likelihood = bilby.likelihood.GaussianLikelihood(
-            self.x, self.y, self.model, self.sigma
+            self.x, self.y, model, self.sigma
         )
 
         self.priors = bilby.core.prior.PriorDict()
         self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
         self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
-        self.kwargs = dict(
-            save=False,
-            conversion_function=self.conversion_function,
-            verbose=True,
-        )
+        self._remove_tree()
         bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
 
     @staticmethod
@@ -42,226 +108,83 @@ class TestRunningSamplers(unittest.TestCase):
         del self.likelihood
         del self.priors
         bilby.core.utils.command_line_args.bilby_test_mode = False
-        shutil.rmtree("outdir")
-
-    def test_run_cpnest(self):
-        pytest.importorskip("cpnest")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="cpnest",
-            nlive=100,
-            resume=False,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dnest4(self):
-        pytest.importorskip("dnest4")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dnest4",
-            max_num_levels=2,
-            num_steps=10,
-            new_level_interval=10,
-            num_per_step=10,
-            thread_steps=1,
-            num_particles=50,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dynesty(self):
-        pytest.importorskip("dynesty")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dynesty",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dynamic_dynesty(self):
-        pytest.importorskip("dynesty")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dynamic_dynesty",
-            nlive_init=100,
-            nlive_batch=100,
-            dlogz_init=1.0,
-            maxbatch=0,
-            maxcall=100,
-            bound="single",
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_emcee(self):
-        pytest.importorskip("emcee")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="emcee",
-            iterations=1000,
-            nwalkers=10,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_kombine(self):
-        pytest.importorskip("kombine")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="kombine",
-            iterations=2000,
-            nwalkers=20,
-            autoburnin=False,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_nestle(self):
-        pytest.importorskip("nestle")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="nestle",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_nessai(self):
-        pytest.importorskip("nessai")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="nessai",
-            nlive=100,
-            poolsize=1000,
-            max_iteration=1000,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_pypolychord(self):
-        pytest.importorskip("pypolychord")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pypolychord",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_ptemcee(self):
-        pytest.importorskip("ptemcee")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ptemcee",
-            nsamples=100,
-            nwalkers=50,
-            burn_in_act=1,
-            ntemps=1,
-            frac_threshold=0.5,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    @pytest.mark.xfail(
-        raises=AttributeError,
-        reason="Dependency issue with pymc3 causes attribute error on import",
-    )
-    def test_run_pymc3(self):
-        pytest.importorskip("pymc3")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pymc3",
-            draws=50,
-            tune=50,
-            n_init=250,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_pymultinest(self):
-        pytest.importorskip("pymultinest")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pymultinest",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_PTMCMCSampler(self):
-        pytest.importorskip("PTMCMCSampler")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="PTMCMCsampler",
-            Niter=101,
-            burn=2,
-            isave=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_ultranest(self):
-        pytest.importorskip("ultranest")
-        # run using NestedSampler (with nlive specified)
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ultranest",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-        # run using ReactiveNestedSampler (with no nlive given)
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ultranest",
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_bilby_mcmc(self):
+        self._remove_tree()
+
+    def _remove_tree(self):
+        try:
+            shutil.rmtree("outdir")
+        except OSError:
+            pass
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_run_sampler_single(self, sampler):
+        self._run_sampler(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_run_sampler_pool(self, sampler):
+        self._run_sampler(sampler, pool_size=2)
+
+    def _run_sampler(self, sampler, pool_size, **extra_kwargs):
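+        # Skip if the backing package is missing, or if a pool was requested
+        # for a sampler that cannot use one; otherwise run a short inference
+        # and check that the conversion function was applied.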
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
+        kwargs = _sampler_kwargs[sampler]
         res = bilby.run_sampler(
             likelihood=self.likelihood,
             priors=self.priors,
-            sampler="bilby_mcmc",
-            nsamples=200,
-            **self.kwargs,
-            printdt=1,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
+            sampler=sampler,
+            save=False,
+            npool=pool_size,
+            conversion_function=self.conversion_function,
+            **kwargs,
+            **extra_kwargs,
+        )
+        assert "derived" in res.posterior
+        assert res.log_likelihood_evaluations is not None
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_single(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_pool(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=2)
+
+    def _run_with_signal_handling(self, sampler, pool_size=1):
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler.lower()].hard_exit:
+            pytest.skip(f"{sampler} hard exits, can't test signal handling.")
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        if sys.version_info.minor == 8 and sampler.lower() == "cpnest":
+            pytest.skip("Pool interruption is broken for cpnest with py3.8")
+        if sampler.lower() == "nessai" and pool_size > 1:
+            pytest.skip(
+                "Interrupting with a pool is failing in pytest. "
+                "Likely due to interactions with the signal handling in nessai."
+            )
+        pid = os.getpid()
+        print(sampler)
+
+        def trigger_signal():
+            # Crude fixed delay; more robustly, wait until the sampler has started
+            time.sleep(4)
+            os.kill(pid, SIGINT)
+
+        thread = threading.Thread(target=trigger_signal)
+        thread.daemon = True
+        thread.start()
+
+        self.likelihood._func = slow_func
+
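+        # Restart the run until the SIGINT from the helper thread lands; the
+        # sampler should exit with the requested code (5), or raise
+        # KeyboardInterrupt where it re-raises instead of exiting.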
+        with self.assertRaises((SystemExit, KeyboardInterrupt)):
+            try:
+                while True:
+                    self._run_sampler(sampler=sampler, pool_size=pool_size, exit_code=5)
+            except SystemExit as error:
+                self.assertEqual(error.code, 5)
+                raise
 
 
 if __name__ == "__main__":