diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 12a355c322abc32ddc26e258b6529387d2bd57a3..eaf6150f9741f99d0bdda8d326ca3c877bac648e 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -156,9 +156,10 @@ python-3.9:
   stage: test
   script:
     - python -m pip install .
+    - python -m pip install schwimmbad
     - python -m pip list installed
 
-    - pytest test/integration/sampler_run_test.py --durations 10
+    - pytest test/integration/sampler_run_test.py --durations 10 -v
 
 python-3.8-samplers:
   <<: *test-sampler
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c2c49cea108d3c1ae687cabc8bbc479cd9b6c596..a78a604a57217f0d579213ef06be28307664da50 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ repos:
     hooks:
       - id: black
         language_version: python3
-        files: '(^bilby/bilby_mcmc/|^examples/)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
 -   repo: https://github.com/codespell-project/codespell
     rev: v2.1.0
     hooks:
@@ -20,7 +20,7 @@ repos:
     hooks:
     -   id: isort # sort imports alphabetically and separates import into sections
         args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
-        files: '(^bilby/bilby_mcmc/|^examples/)'
+        files: '^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)'
 -   repo: https://github.com/datarootsio/databooks
     rev: 0.1.14
     hooks:
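
The `files` pattern above now anchors once at the start of the path and adds `bilby/core/sampler/` to the set of auto-formatted directories. A quick sanity check of the new pattern (the paths are hypothetical examples; pre-commit matches with `re.search` semantics):

```python
# Check which paths the new pre-commit `files` regex covers.
import re

new_pattern = re.compile(r"^(bilby/bilby_mcmc/|bilby/core/sampler/|examples/)")

for path in [
    "bilby/bilby_mcmc/sampler.py",   # covered before and after this change
    "bilby/core/sampler/cpnest.py",  # newly covered by this change
    "bilby/core/utils.py",           # still not auto-formatted
]:
    print(path, bool(new_pattern.search(path)))
```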
diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 66e6e4392a9456161c123ff84a1d5d73c4ed216f..ec6c2abaaf424f8904149cebdde60095262fe1a0 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -1,6 +1,5 @@
 import datetime
 import os
-import signal
 import time
 from collections import Counter
 
@@ -8,7 +7,13 @@ import numpy as np
 import pandas as pd
 
 from ..core.result import rejection_sample
-from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError
+from ..core.sampler.base_sampler import (
+    MCMCSampler,
+    ResumeError,
+    SamplerError,
+    _sampling_convenience_dump,
+    signal_wrapper,
+)
 from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
 from . import proposals
 from .chain import Chain, Sample
@@ -131,7 +136,6 @@ class Bilby_MCMC(MCMCSampler):
         autocorr_c=5,
         L1steps=100,
         L2steps=3,
-        npool=1,
         printdt=60,
         min_tau=1,
         proposal_cycle="default",
@@ -172,7 +176,6 @@ class Bilby_MCMC(MCMCSampler):
         self.check_point_plot = check_point_plot
         self.diagnostic = diagnostic
         self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
-        self.npool = self.kwargs["npool"]
         self.L1steps = self.kwargs["L1steps"]
         self.L2steps = self.kwargs["L2steps"]
         self.pt_inputs = ParallelTemperingInputs(
@@ -194,17 +197,6 @@ class Bilby_MCMC(MCMCSampler):
         self.verify_configuration()
         self.verbose = verbose
 
-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern."
-            )
-
     def verify_configuration(self):
         if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
             logger.warning("Burn-in inefficiency fraction greater than 10%")
@@ -223,6 +215,7 @@ class Bilby_MCMC(MCMCSampler):
     def target_nsamples(self):
         return self.kwargs["target_nsamples"]
 
+    @signal_wrapper
     def run_sampler(self):
         self._setup_pool()
         self.setup_chain_set()
@@ -377,31 +370,12 @@ class Bilby_MCMC(MCMCSampler):
             f"setup:\n{self.get_setup_string()}"
         )
 
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.npool == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                        signum, self.exit_code
-                    )
-                )
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
-
     def write_current_state(self):
         import dill
 
+        if not hasattr(self, "ptsampler"):
+            logger.debug("Attempted checkpoint before initialization")
+            return
         logger.debug("Check point")
         check_directory_exists_and_if_not_mkdir(self.outdir)
 
@@ -534,39 +508,6 @@ class Bilby_MCMC(MCMCSampler):
                         all_samples=ptsampler.samples,
                     )
 
-    def _setup_pool(self):
-        if self.npool > 1:
-            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
-            import multiprocessing
-
-            self.pool = multiprocessing.Pool(
-                processes=self.npool,
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio,
-                ),
-            )
-        else:
-            self.pool = None
-
-        _initialize_global_variables(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            search_parameter_keys=self._search_parameter_keys,
-            use_ratio=self.use_ratio,
-        )
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            logger.info("Finished closing worker pool.")
-
 
 class BilbyPTMCMCSampler(object):
     def __init__(
@@ -579,7 +520,6 @@ class BilbyPTMCMCSampler(object):
         use_ratio,
         evidence_method,
     ):
-
         self.set_pt_inputs(pt_inputs)
         self.use_ratio = use_ratio
         self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
@@ -597,7 +537,7 @@ class BilbyPTMCMCSampler(object):
 
         self._nsamples_dict = {}
         self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
-            _priors
+            _sampling_convenience_dump.priors
         )
         self.sampling_time = 0
         self.ln_z_dict = dict()
@@ -612,7 +552,7 @@ class BilbyPTMCMCSampler(object):
         elif pt_inputs.Tmax is not None:
             betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
         elif pt_inputs.Tmax_from_SNR is not None:
-            ndim = len(_priors.non_fixed_keys)
+            ndim = len(_sampling_convenience_dump.priors.non_fixed_keys)
             target_hot_likelihood = ndim / 2
             Tmax = pt_inputs.Tmax_from_SNR**2 / (2 * target_hot_likelihood)
             betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
@@ -1140,12 +1080,14 @@ class BilbyMCMCSampler(object):
         self.Eindex = Eindex
         self.use_ratio = use_ratio
 
-        self.parameters = _priors.non_fixed_keys
+        self.parameters = _sampling_convenience_dump.priors.non_fixed_keys
         self.ndim = len(self.parameters)
 
-        full_sample_dict = _priors.sample()
+        full_sample_dict = _sampling_convenience_dump.priors.sample()
         initial_sample = {
-            k: v for k, v in full_sample_dict.items() if k in _priors.non_fixed_keys
+            k: v
+            for k, v in full_sample_dict.items()
+            if k in _sampling_convenience_dump.priors.non_fixed_keys
         }
         initial_sample = Sample(initial_sample)
         initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
@@ -1168,7 +1110,10 @@ class BilbyMCMCSampler(object):
                 warn = False
 
             self.proposal_cycle = proposals.get_proposal_cycle(
-                proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn
+                proposal_cycle,
+                _sampling_convenience_dump.priors,
+                L1steps=self.chain.L1steps,
+                warn=warn,
             )
         elif isinstance(proposal_cycle, proposals.ProposalCycle):
             self.proposal_cycle = proposal_cycle
@@ -1185,17 +1130,17 @@ class BilbyMCMCSampler(object):
         self.stop_after_convergence = convergence_inputs.stop_after_convergence
 
     def log_likelihood(self, sample):
-        _likelihood.parameters.update(sample.sample_dict)
+        _sampling_convenience_dump.likelihood.parameters.update(sample.sample_dict)
 
         if self.use_ratio:
-            logl = _likelihood.log_likelihood_ratio()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            logl = _likelihood.log_likelihood()
+            logl = _sampling_convenience_dump.likelihood.log_likelihood()
 
         return logl
 
     def log_prior(self, sample):
-        return _priors.ln_prob(sample.parameter_only_dict)
+        return _sampling_convenience_dump.priors.ln_prob(sample.parameter_only_dict)
 
     def accept_proposal(self, prop, proposal):
         self.chain.append(prop)
@@ -1293,8 +1238,10 @@ class BilbyMCMCSampler(object):
         zerotemp_logl = hot_samples[LOGLKEY]
 
         # Revert to true likelihood if needed
-        if _use_ratio:
-            zerotemp_logl += _likelihood.noise_log_likelihood()
+        if _sampling_convenience_dump.use_ratio:
+            zerotemp_logl += (
+                _sampling_convenience_dump.likelihood.noise_log_likelihood()
+            )
 
         # Calculate normalised weights
         log_weights = (1 - beta) * zerotemp_logl
@@ -1322,29 +1269,3 @@ class BilbyMCMCSampler(object):
 def call_step(sampler):
     sampler = sampler.step()
     return sampler
-
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-    likelihood,
-    priors,
-    search_parameter_keys,
-    use_ratio,
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
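
The module-level globals removed here are superseded by the shared `_SamplingContainer` added to `base_sampler.py` later in this diff. A standalone sketch of the pattern, with toy stand-ins for the likelihood and priors (only the container-plus-initializer shape mirrors the diff):

```python
# Globals-to-container pattern: one attrs container replaces bare module
# globals, and a pool initializer populates it in every worker process.
import multiprocessing

import attr


@attr.s
class _SamplingContainer:
    likelihood = attr.ib(default=None)
    priors = attr.ib(default=None)
    use_ratio = attr.ib(default=False)


_dump = _SamplingContainer()


def _init_worker(likelihood, priors, use_ratio):
    # Runs once per worker, so each process sees a populated container.
    _dump.likelihood = likelihood
    _dump.priors = priors
    _dump.use_ratio = use_ratio


def _task(x):
    # Workers read attributes off the container instead of module globals.
    return (_dump.use_ratio, x)


if __name__ == "__main__":
    with multiprocessing.Pool(
        2, initializer=_init_worker, initargs=("toy-likelihood", "toy-priors", True)
    ) as pool:
        print(pool.map(_task, range(3)))
```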
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index bd259f97a74c4784ecf4cf8b5fe2c742e7bc79b0..56c32ed3d88e1c95062954c81f15b78c12b4d792 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -1,15 +1,20 @@
+import datetime
 import inspect
 import sys
-import datetime
 
 import bilby
-from ..utils import command_line_args, logger, loaded_modules_dict
-from ..prior import PriorDict, DeltaFunction
+from bilby.bilby_mcmc import Bilby_MCMC
+
+from ..prior import DeltaFunction, PriorDict
+from ..utils import command_line_args, loaded_modules_dict, logger
+from . import proposal
 from .base_sampler import Sampler, SamplingMarginalisedParameterError
 from .cpnest import Cpnest
+from .dnest4 import DNest4
 from .dynamic_dynesty import DynamicDynesty
 from .dynesty import Dynesty
 from .emcee import Emcee
+from .fake_sampler import FakeSampler
 from .kombine import Kombine
 from .nessai import Nessai
 from .nestle import Nestle
@@ -19,11 +24,7 @@ from .ptmcmc import PTMCMCSampler
 from .pymc3 import Pymc3
 from .pymultinest import Pymultinest
 from .ultranest import Ultranest
-from .fake_sampler import FakeSampler
-from .dnest4 import DNest4
 from .zeus import Zeus
-from bilby.bilby_mcmc import Bilby_MCMC
-from . import proposal
 
 IMPLEMENTED_SAMPLERS = {
     "bilby_mcmc": Bilby_MCMC,
@@ -49,7 +50,7 @@ if command_line_args.sampler_help:
     sampler = command_line_args.sampler_help
     if sampler in IMPLEMENTED_SAMPLERS:
         sampler_class = IMPLEMENTED_SAMPLERS[sampler]
-        print('Help for sampler "{}":'.format(sampler))
+        print(f'Help for sampler "{sampler}":')
         print(sampler_class.__doc__)
     else:
         if sampler == "None":
@@ -58,8 +59,8 @@ if command_line_args.sampler_help:
                 "the name of the sampler"
             )
         else:
-            print("Requested sampler {} not implemented".format(sampler))
-        print("Available samplers = {}".format(IMPLEMENTED_SAMPLERS))
+            print(f"Requested sampler {sampler} not implemented")
+        print(f"Available samplers = {IMPLEMENTED_SAMPLERS}")
 
     sys.exit()
 
@@ -81,7 +82,7 @@ def run_sampler(
     gzip=False,
     result_class=None,
     npool=1,
-    **kwargs
+    **kwargs,
 ):
     """
     The primary interface to easy parameter estimation
@@ -144,9 +145,7 @@ def run_sampler(
         An object containing the results
     """
 
-    logger.info(
-        "Running for label '{}', output will be saved to '{}'".format(label, outdir)
-    )
+    logger.info(f"Running for label '{label}', output will be saved to '{outdir}'")
 
     if clean:
         command_line_args.clean = clean
@@ -174,7 +173,7 @@ def run_sampler(
         meta_data = dict()
     likelihood.label = label
     likelihood.outdir = outdir
-    meta_data['likelihood'] = likelihood.meta_data
+    meta_data["likelihood"] = likelihood.meta_data
     meta_data["loaded_modules"] = loaded_modules_dict()
 
     if command_line_args.bilby_zero_likelihood_mode:
@@ -198,11 +197,11 @@ def run_sampler(
                 plot=plot,
                 result_class=result_class,
                 npool=npool,
-                **kwargs
+                **kwargs,
             )
         else:
             print(IMPLEMENTED_SAMPLERS)
-            raise ValueError("Sampler {} not yet implemented".format(sampler))
+            raise ValueError(f"Sampler {sampler} not yet implemented")
     elif inspect.isclass(sampler):
         sampler = sampler.__init__(
             likelihood,
@@ -214,12 +213,12 @@ def run_sampler(
             injection_parameters=injection_parameters,
             meta_data=meta_data,
             npool=npool,
-            **kwargs
+            **kwargs,
         )
     else:
         raise ValueError(
             "Provided sampler should be a Sampler object or name of a known "
-            "sampler: {}.".format(", ".join(IMPLEMENTED_SAMPLERS.keys()))
+            f"sampler: {', '.join(IMPLEMENTED_SAMPLERS.keys())}."
         )
 
     if sampler.cached_result:
@@ -240,23 +239,22 @@ def run_sampler(
         elif isinstance(result.sampling_time, (float, int)):
             result.sampling_time = datetime.timedelta(result.sampling_time)
 
-        logger.info('Sampling time: {}'.format(result.sampling_time))
+        logger.info(f"Sampling time: {result.sampling_time}")
         # Convert sampling time into seconds
         result.sampling_time = result.sampling_time.total_seconds()
 
         if sampler.use_ratio:
             result.log_noise_evidence = likelihood.noise_log_likelihood()
             result.log_bayes_factor = result.log_evidence
-            result.log_evidence = \
-                result.log_bayes_factor + result.log_noise_evidence
+            result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
         else:
             result.log_noise_evidence = likelihood.noise_log_likelihood()
-            result.log_bayes_factor = \
-                result.log_evidence - result.log_noise_evidence
+            result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
 
         if None not in [result.injection_parameters, conversion_function]:
             result.injection_parameters = conversion_function(
-                result.injection_parameters)
+                result.injection_parameters
+            )
 
         # Initial save of the sampler in case of failure in samples_to_posterior
         if save:
@@ -267,9 +265,12 @@ def run_sampler(
 
     # Check if the posterior has already been created
     if getattr(result, "_posterior", None) is None:
-        result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
-                                    conversion_function=conversion_function,
-                                    npool=npool)
+        result.samples_to_posterior(
+            likelihood=likelihood,
+            priors=result.priors,
+            conversion_function=conversion_function,
+            npool=npool,
+        )
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
@@ -277,7 +278,7 @@ def run_sampler(
 
     if plot:
         result.plot_corner()
-    logger.info("Summary of results:\n{}".format(result))
+    logger.info(f"Summary of results:\n{result}")
     return result
 
 
@@ -286,7 +287,5 @@ def _check_marginalized_parameters_not_sampled(likelihood, priors):
         if key in priors:
             if not isinstance(priors[key], (float, DeltaFunction)):
                 raise SamplingMarginalisedParameterError(
-                    "Likelihood is {} marginalized but you are trying to sample in {}. ".format(
-                        key, key
-                    )
+                    f"Likelihood is {key} marginalized but you are trying to sample in {key}. "
                 )
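
For context, a minimal sketch of the `run_sampler` entry point this file provides, assuming bilby with this patch applied; the linear model, priors, and sampler settings are illustrative choices, not part of the diff:

```python
# Toy end-to-end use of bilby.run_sampler with the bilby_mcmc sampler.
import bilby
import numpy as np

x = np.linspace(0, 1, 100)
y = 2.0 * x + 1.0 + np.random.normal(0, 0.1, len(x))


def model(x, m, c):
    return m * x + c


likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
priors = dict(
    m=bilby.core.prior.Uniform(0, 5, "m"),
    c=bilby.core.prior.Uniform(0, 5, "c"),
)

result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="bilby_mcmc",  # dispatched via IMPLEMENTED_SAMPLERS above
    nsamples=500,
    npool=1,               # now forwarded to the Sampler base class
    outdir="outdir",
    label="linear_regression_demo",
)
print(result)
```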
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 215104a98087bb91abc3964684edf1a0a5d0d458..c30f76045d285d0362e6d972330694202e4db381 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -1,18 +1,111 @@
 import datetime
 import distutils.dir_util
-import numpy as np
 import os
+import shutil
+import signal
+import sys
 import tempfile
+import time
 
+import attr
+import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir, command_line_args, Counter
-from ..prior import Prior, PriorDict, DeltaFunction, Constraint
+from ..prior import Constraint, DeltaFunction, Prior, PriorDict
 from ..result import Result, read_in_result
+from ..utils import (
+    Counter,
+    check_directory_exists_and_if_not_mkdir,
+    command_line_args,
+    logger,
+)
+
+
+@attr.s
+class _SamplingContainer:
+    """
+    A container class for objects that are stored independently in each thread
+    for some samplers.
+
+    A single instance of this will appear in this module that can be accessed
+    by the individual samplers.
+
+    This includes the:
+
+    - likelihood (bilby.core.likelihood.Likelihood)
+    - priors (bilby.core.prior.PriorDict)
+    - search_parameter_keys (list)
+    - use_ratio (bool)
+    """
+
+    likelihood = attr.ib(default=None)
+    priors = attr.ib(default=None)
+    search_parameter_keys = attr.ib(default=None)
+    use_ratio = attr.ib(default=False)
+
+
+_sampling_convenience_dump = _SamplingContainer()
+
+
+def _initialize_global_variables(
+    likelihood,
+    priors,
+    search_parameter_keys,
+    use_ratio,
+):
+    """
+    Store a global copy of the likelihood, priors, and search keys for
+    multiprocessing.
+    """
+    global _sampling_convenience_dump
+    _sampling_convenience_dump.likelihood = likelihood
+    _sampling_convenience_dump.priors = priors
+    _sampling_convenience_dump.search_parameter_keys = search_parameter_keys
+    _sampling_convenience_dump.use_ratio = use_ratio
+
+
+def signal_wrapper(method):
+    """
+    Decorator to wrap a method of a class to set system signals before running
+    and reset them after.
+
+    Parameters
+    ==========
+    method: callable
+        The method to call; this assumes the first argument is `self`
+        and that `self` has a `write_current_state_and_exit` method.
+
+    Returns
+    =======
+    output: callable
+        The wrapped method.
+    """
+
+    def wrapped(self, *args, **kwargs):
+        try:
+            old_term = signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
+            old_int = signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+            old_alarm = signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
+            _set = True
+        except (AttributeError, ValueError):
+            _set = False
+            logger.debug(
+                "Setting signal attributes unavailable on this system. "
+                "This is likely the case if you are running on a Windows machine "
+                "and can be safely ignored."
+            )
+        output = method(self, *args, **kwargs)
+        if _set:
+            signal.signal(signal.SIGTERM, old_term)
+            signal.signal(signal.SIGINT, old_int)
+            signal.signal(signal.SIGALRM, old_alarm)
+        return output
+
+    return wrapped
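
A short sketch of the intended usage, mirroring the `@signal_wrapper` applied to `Bilby_MCMC.run_sampler` elsewhere in this diff; `DemoSampler` is a hypothetical stand-in, and this assumes bilby with this patch installed:

```python
# Handlers are installed on entry to run_sampler and the previous handlers
# are restored once the method returns.
from bilby.core.sampler.base_sampler import signal_wrapper


class DemoSampler:
    exit_code = 130

    def write_current_state_and_exit(self, signum=None, frame=None):
        # Invoked if SIGTERM/SIGINT/SIGALRM arrives while run_sampler runs.
        print(f"interrupted by signal {signum}, checkpointing")

    @signal_wrapper
    def run_sampler(self):
        return "samples"


print(DemoSampler().run_sampler())
```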
 
 
 class Sampler(object):
-    """ A sampler object to aid in setting up an inference run
+    """A sampler object to aid in setting up an inference run
 
     Parameters
     ==========
@@ -76,6 +169,10 @@ class Sampler(object):
         System exit code to return on interrupt
     kwargs: dict
         Dictionary of keyword arguments that can be used in the external sampler
+    hard_exit: bool
+        Whether the implemented sampler exits hard (:code:`os._exit` rather
+        than :code:`sys.exit`). The latter can be caught as a
+        :code:`SystemExit` exception; the former cannot.
 
     Raises
     ======
@@ -89,15 +186,36 @@ class Sampler(object):
         If some of the priors can't be sampled
 
     """
+
     default_kwargs = dict()
-    npool_equiv_kwargs = ['queue_size', 'threads', 'nthreads', 'npool']
+    npool_equiv_kwargs = [
+        "npool",
+        "queue_size",
+        "threads",
+        "nthreads",
+        "cores",
+        "n_pool",
+    ]
+    hard_exit = False
 
     def __init__(
-            self, likelihood, priors, outdir='outdir', label='label',
-            use_ratio=False, plot=False, skip_import_verification=False,
-            injection_parameters=None, meta_data=None, result_class=None,
-            likelihood_benchmark=False, soft_init=False, exit_code=130,
-            **kwargs):
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        injection_parameters=None,
+        meta_data=None,
+        result_class=None,
+        likelihood_benchmark=False,
+        soft_init=False,
+        exit_code=130,
+        npool=1,
+        **kwargs,
+    ):
         self.likelihood = likelihood
         if isinstance(priors, PriorDict):
             self.priors = priors
@@ -108,6 +226,7 @@ class Sampler(object):
         self.injection_parameters = injection_parameters
         self.meta_data = meta_data
         self.use_ratio = use_ratio
+        self._npool = npool
         if not skip_import_verification:
             self._verify_external_sampler()
         self.external_sampler_function = None
@@ -159,7 +278,7 @@ class Sampler(object):
 
     @property
     def kwargs(self):
-        """dict: Container for the kwargs. Has more sophisticated logic in subclasses """
+        """dict: Container for the kwargs. Has more sophisticated logic in subclasses"""
         return self._kwargs
 
     @kwargs.setter
@@ -170,7 +289,7 @@ class Sampler(object):
         self._verify_kwargs_against_default_kwargs()
 
     def _translate_kwargs(self, kwargs):
-        """ Template for child classes """
+        """Template for child classes"""
         pass
 
     @property
@@ -180,10 +299,11 @@ class Sampler(object):
     def _verify_external_sampler(self):
         external_sampler_name = self.external_sampler_name
         try:
-            self.external_sampler = __import__(external_sampler_name)
+            __import__(external_sampler_name)
         except (ImportError, SystemExit):
             raise SamplerNotInstalledError(
-                "Sampler {} is not installed on this system".format(external_sampler_name))
+                f"Sampler {external_sampler_name} is not installed on this system"
+            )
 
     def _verify_kwargs_against_default_kwargs(self):
         """
@@ -195,8 +315,8 @@ class Sampler(object):
         for user_input in self.kwargs.keys():
             if user_input not in args:
                 logger.warning(
-                    "Supplied argument '{}' not an argument of '{}', removing."
-                    .format(user_input, self.__class__.__name__))
+                    f"Supplied argument '{user_input}' not an argument of '{self.__class__.__name__}', removing."
+                )
                 bad_keys.append(user_input)
         for key in bad_keys:
             self.kwargs.pop(key)
@@ -208,8 +328,10 @@ class Sampler(object):
         the respective parameter is fixed.
         """
         for key in self.priors:
-            if isinstance(self.priors[key], Prior) \
-                    and self.priors[key].is_fixed is False:
+            if (
+                isinstance(self.priors[key], Prior)
+                and self.priors[key].is_fixed is False
+            ):
                 self._search_parameter_keys.append(key)
             elif isinstance(self.priors[key], Constraint):
                 self._constraint_parameter_keys.append(key)
@@ -219,9 +341,9 @@ class Sampler(object):
 
         logger.info("Search parameters:")
         for key in self._search_parameter_keys + self._constraint_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key]))
+            logger.info(f"  {key} = {self.priors[key]}")
         for key in self._fixed_parameter_keys:
-            logger.info('  {} = {}'.format(key, self.priors[key].peak))
+            logger.info(f"  {key} = {self.priors[key].peak}")
 
     def _initialise_result(self, result_class):
         """
@@ -231,27 +353,30 @@ class Sampler(object):
 
         """
         result_kwargs = dict(
-            label=self.label, outdir=self.outdir,
+            label=self.label,
+            outdir=self.outdir,
             sampler=self.__class__.__name__.lower(),
             search_parameter_keys=self._search_parameter_keys,
             fixed_parameter_keys=self._fixed_parameter_keys,
             constraint_parameter_keys=self._constraint_parameter_keys,
-            priors=self.priors, meta_data=self.meta_data,
+            priors=self.priors,
+            meta_data=self.meta_data,
             injection_parameters=self.injection_parameters,
-            sampler_kwargs=self.kwargs, use_ratio=self.use_ratio)
+            sampler_kwargs=self.kwargs,
+            use_ratio=self.use_ratio,
+        )
 
         if result_class is None:
             result = Result(**result_kwargs)
         elif issubclass(result_class, Result):
             result = result_class(**result_kwargs)
         else:
-            raise ValueError(
-                "Input result_class={} not understood".format(result_class))
+            raise ValueError(f"Input result_class={result_class} not understood")
 
         return result
 
     def _verify_parameters(self):
-        """ Evaluate a set of parameters drawn from the prior
+        """Evaluate a set of parameters drawn from the prior
 
         Tests if the likelihood evaluation passes
 
@@ -264,20 +389,22 @@ class Sampler(object):
 
         if self.priors.test_has_redundant_keys():
             raise IllegalSamplingSetError(
-                "Your sampling set contains redundant parameters.")
+                "Your sampling set contains redundant parameters."
+            )
 
         theta = self.priors.sample_subset_constrained_as_array(
-            self.search_parameter_keys, size=1)[:, 0]
+            self.search_parameter_keys, size=1
+        )[:, 0]
         try:
             self.log_likelihood(theta)
         except TypeError as e:
             raise TypeError(
-                "Likelihood evaluation failed with message: \n'{}'\n"
-                "Have you specified all the parameters:\n{}"
-                .format(e, self.likelihood.parameters))
+                f"Likelihood evaluation failed with message: \n'{e}'\n"
+                f"Have you specified all the parameters:\n{self.likelihood.parameters}"
+            )
 
     def _time_likelihood(self, n_evaluations=100):
-        """ Times the likelihood evaluation and print an info message
+        """Times the likelihood evaluation and print an info message
 
         Parameters
         ==========
@@ -289,7 +416,8 @@ class Sampler(object):
         t1 = datetime.datetime.now()
         for _ in range(n_evaluations):
             theta = self.priors.sample_subset_constrained_as_array(
-                self._search_parameter_keys, size=1)[:, 0]
+                self._search_parameter_keys, size=1
+            )[:, 0]
             self.log_likelihood(theta)
         total_time = (datetime.datetime.now() - t1).total_seconds()
         self._log_likelihood_eval_time = total_time / n_evaluations
@@ -298,8 +426,9 @@ class Sampler(object):
             self._log_likelihood_eval_time = np.nan
             logger.info("Unable to measure single likelihood time")
         else:
-            logger.info("Single likelihood evaluation took {:.3e} s"
-                        .format(self._log_likelihood_eval_time))
+            logger.info(
+                f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s"
+            )
 
     def _verify_use_ratio(self):
         """
@@ -309,9 +438,9 @@ class Sampler(object):
         try:
             self.priors.sample_subset(self.search_parameter_keys)
         except (KeyError, AttributeError):
-            logger.error("Cannot sample from priors with keys: {}.".format(
-                self.search_parameter_keys
-            ))
+            logger.error(
+                f"Cannot sample from priors with keys: {self.search_parameter_keys}."
+            )
             raise
         if self.use_ratio is False:
             logger.debug("use_ratio set to False")
@@ -322,14 +451,14 @@ class Sampler(object):
         if self.use_ratio is True and ratio_is_nan:
             logger.warning(
                 "You have requested to use the loglikelihood_ratio, but it "
-                " returns a NaN")
+                " returns a NaN"
+            )
         elif self.use_ratio is None and not ratio_is_nan:
-            logger.debug(
-                "use_ratio not spec. but gives valid answer, setting True")
+            logger.debug("use_ratio not spec. but gives valid answer, setting True")
             self.use_ratio = True
 
     def prior_transform(self, theta):
-        """ Prior transform method that is passed into the external sampler.
+        """Prior transform method that is passed into the external sampler.
 
         Parameters
         ==========
@@ -355,8 +484,7 @@ class Sampler(object):
         float: Joint ln prior probability of theta
 
         """
-        params = {
-            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
         return self.priors.ln_prob(params)
 
     def log_likelihood(self, theta):
@@ -378,8 +506,7 @@ class Sampler(object):
                 self.likelihood_count.increment()
             except AttributeError:
                 pass
-        params = {
-            key: t for key, t in zip(self._search_parameter_keys, theta)}
+        params = {key: t for key, t in zip(self._search_parameter_keys, theta)}
         self.likelihood.parameters.update(params)
         if self.use_ratio:
             return self.likelihood.log_likelihood_ratio()
@@ -387,7 +514,7 @@ class Sampler(object):
             return self.likelihood.log_likelihood()
 
     def get_random_draw_from_prior(self):
-        """ Get a random draw from the prior distribution
+        """Get a random draw from the prior distribution
 
         Returns
         =======
@@ -397,13 +524,12 @@ class Sampler(object):
 
         """
         new_sample = self.priors.sample()
-        draw = np.array(list(new_sample[key]
-                             for key in self._search_parameter_keys))
+        draw = np.array(list(new_sample[key] for key in self._search_parameter_keys))
         self.check_draw(draw)
         return draw
 
     def get_initial_points_from_prior(self, npoints=1):
-        """ Method to draw a set of live points from the prior
+        """Method to draw a set of live points from the prior
 
         This iterates over draws from the prior until all the samples have a
         finite prior and likelihood (relevant for constrained priors).
@@ -457,9 +583,11 @@ class Sampler(object):
         """
         log_p = self.log_prior(theta)
         log_l = self.log_likelihood(theta)
-        return \
-            self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
-            self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
+        return self._check_bad_value(
+            val=log_p, warning=warning, theta=theta, label="prior"
+        ) and self._check_bad_value(
+            val=log_l, warning=warning, theta=theta, label="likelihood"
+        )
 
     @staticmethod
     def _check_bad_value(val, warning, theta, label):
@@ -467,7 +595,7 @@ class Sampler(object):
         bad_values = [np.inf, np.nan_to_num(np.inf)]
         if val in bad_values or np.isnan(val):
             if warning:
-                logger.warning(f'Prior draw {theta} has inf {label}')
+                logger.warning(f"Prior draw {theta} has inf {label}")
             return False
         return True
 
@@ -485,7 +613,7 @@ class Sampler(object):
         raise ValueError("Method not yet implemented")
 
     def _check_cached_result(self):
-        """ Check if the cached data file exists and can be used """
+        """Check if the cached data file exists and can be used"""
 
         if command_line_args.clean:
             logger.debug("Command line argument clean given, forcing rerun")
@@ -493,30 +621,30 @@ class Sampler(object):
             return
 
         try:
-            self.cached_result = read_in_result(
-                outdir=self.outdir, label=self.label)
+            self.cached_result = read_in_result(outdir=self.outdir, label=self.label)
         except IOError:
             self.cached_result = None
 
         if command_line_args.use_cached:
-            logger.debug(
-                "Command line argument cached given, no cache check performed")
+            logger.debug("Command line argument cached given, no cache check performed")
             return
 
         logger.debug("Checking cached data")
         if self.cached_result:
-            check_keys = ['search_parameter_keys', 'fixed_parameter_keys']
+            check_keys = ["search_parameter_keys", "fixed_parameter_keys"]
             use_cache = True
             for key in check_keys:
-                if self.cached_result._check_attribute_match_to_other_object(
-                        key, self) is False:
-                    logger.debug("Cached value {} is unmatched".format(key))
+                if (
+                    self.cached_result._check_attribute_match_to_other_object(key, self)
+                    is False
+                ):
+                    logger.debug(f"Cached value {key} is unmatched")
                     use_cache = False
             try:
                 # Recursive check the dictionaries allowing for numpy arrays
                 np.testing.assert_equal(
                     self.meta_data["likelihood"],
-                    self.cached_result.meta_data["likelihood"]
+                    self.cached_result.meta_data["likelihood"],
                 )
             except AssertionError:
                 use_cache = False
@@ -531,13 +659,12 @@ class Sampler(object):
                 if type(kwargs_print[k]) in (list, np.ndarray):
                     array_repr = np.array(kwargs_print[k])
                     if array_repr.size > 10:
-                        kwargs_print[k] = ('array_like, shape={}'
-                                           .format(array_repr.shape))
+                        kwargs_print[k] = f"array_like, shape={array_repr.shape}"
                 elif type(kwargs_print[k]) == DataFrame:
-                    kwargs_print[k] = ('DataFrame, shape={}'
-                                       .format(kwargs_print[k].shape))
-            logger.info("Using sampler {} with kwargs {}".format(
-                self.__class__.__name__, kwargs_print))
+                    kwargs_print[k] = f"DataFrame, shape={kwargs_print[k].shape}"
+            logger.info(
+                f"Using sampler {self.__class__.__name__} with kwargs {kwargs_print}"
+            )
 
     def calc_likelihood_count(self):
         if self.likelihood_benchmark:
@@ -545,15 +672,100 @@ class Sampler(object):
         else:
             return None
 
+    @property
+    def npool(self):
+        for key in self.npool_equiv_kwargs:
+            if key in self.kwargs:
+                return self.kwargs[key]
+        return self._npool
+
+    def _log_interruption(self, signum=None):
+        if signum == 14:
+            logger.info(
+                f"Run interrupted by alarm signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+        else:
+            logger.info(
+                f"Run interrupted by signal {signum}: checkpoint and exit on {self.exit_code}"
+            )
+
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Make sure that, if a pool of jobs is running, only the parent tries
+        to checkpoint and exit. Only the parent has a 'pool' attribute.
+
+        Samplers that must hard exit (typically because they wrap a
+        non-Python process) use :code:`os._exit`, which cannot be caught;
+        other samplers exit via :code:`sys.exit`, which can be caught as a
+        :code:`SystemExit`.
+        """
+        if self.npool in (1, None) or getattr(self, "pool", None) is not None:
+            self._log_interruption(signum=signum)
+            self.write_current_state()
+            self._close_pool()
+            if self.hard_exit:
+                os._exit(self.exit_code)
+            else:
+                sys.exit(self.exit_code)
+
+    def _close_pool(self):
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            self.kwargs["pool"] = self.pool
+            logger.info("Finished closing worker pool.")
+
+    def _setup_pool(self):
+        if self.kwargs.get("pool", None) is not None:
+            logger.info("Using user defined pool.")
+            self.pool = self.kwargs["pool"]
+        elif self.npool is not None and self.npool > 1:
+            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
+            import multiprocessing
+
+            self.pool = multiprocessing.Pool(
+                processes=self.npool,
+                initializer=_initialize_global_variables,
+                initargs=(
+                    self.likelihood,
+                    self.priors,
+                    self._search_parameter_keys,
+                    self.use_ratio,
+                ),
+            )
+        else:
+            self.pool = None
+        _initialize_global_variables(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            search_parameter_keys=self._search_parameter_keys,
+            use_ratio=self.use_ratio,
+        )
+        self.kwargs["pool"] = self.pool
+
+    def write_current_state(self):
+        raise NotImplementedError()
+
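
The `_setup_pool`/`_close_pool` helpers above give every sampler the same pool lifecycle: honour a user-supplied `kwargs["pool"]`, fall back to `multiprocessing` when `npool > 1`, and always tear the pool down. A standalone toy sketch of that lifecycle (stand-in class, not bilby's API):

```python
# Minimal pool lifecycle mirroring the base-class helpers above.
import multiprocessing


def _square(x):
    return x * x


class DemoPoolUser:
    def __init__(self, npool=1):
        self.npool = npool
        self.pool = None

    def _setup_pool(self):
        # Mirrors Sampler._setup_pool: only build a pool when npool > 1.
        if self.npool > 1:
            self.pool = multiprocessing.Pool(processes=self.npool)

    def _close_pool(self):
        # Mirrors Sampler._close_pool: close, join, then drop the reference.
        if self.pool is not None:
            self.pool.close()
            self.pool.join()
            self.pool = None

    def run(self, tasks):
        self._setup_pool()
        try:
            if self.pool is not None:
                return self.pool.map(_square, tasks)   # parallel path
            return [_square(t) for t in tasks]         # serial path
        finally:
            self._close_pool()


if __name__ == "__main__":
    print(DemoPoolUser(npool=2).run(range(5)))
```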
 
 class NestedSampler(Sampler):
-    npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints',
-                            'npoint', 'Nlive', 'num_live_points', 'num_particles']
-    walks_equiv_kwargs = ['walks', 'steps', 'nmcmc']
+    npoints_equiv_kwargs = [
+        "nlive",
+        "nlives",
+        "n_live_points",
+        "npoints",
+        "npoint",
+        "Nlive",
+        "num_live_points",
+        "num_particles",
+    ]
+    walks_equiv_kwargs = ["walks", "steps", "nmcmc"]
 
-    def reorder_loglikelihoods(self, unsorted_loglikelihoods, unsorted_samples,
-                               sorted_samples):
-        """ Reorders the stored log-likelihood after they have been reweighted
+    @staticmethod
+    def reorder_loglikelihoods(
+        unsorted_loglikelihoods, unsorted_samples, sorted_samples
+    ):
+        """Reorders the stored log-likelihood after they have been reweighted
 
         This creates a sorting index by matching the reweights `result.samples`
         against the raw samples, then uses this index to sort the
@@ -578,12 +790,12 @@ class NestedSampler(Sampler):
 
         idxs = []
         for ii in range(len(unsorted_loglikelihoods)):
-            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples,
-                                  axis=1))[0]
+            idx = np.where(np.all(sorted_samples[ii] == unsorted_samples, axis=1))[0]
             if len(idx) > 1:
                 logger.warning(
                     "Multiple likelihood matches found between sorted and "
-                    "unsorted samples. Taking the first match.")
+                    "unsorted samples. Taking the first match."
+                )
             idxs.append(idx[0])
         return unsorted_loglikelihoods[idxs]
 
@@ -601,52 +813,34 @@ class NestedSampler(Sampler):
         =======
         float: log_likelihood
         """
-        if self.priors.evaluate_constraints({
-                key: theta[ii] for ii, key in
-                enumerate(self.search_parameter_keys)}):
+        if self.priors.evaluate_constraints(
+            {key: theta[ii] for ii, key in enumerate(self.search_parameter_keys)}
+        ):
             return Sampler.log_likelihood(self, theta)
         else:
             return np.nan_to_num(-np.inf)
 
-    def _setup_run_directory(self):
-        """
-        If using a temporary directory, the output directory is moved to the
-        temporary directory.
-        Used for Dnest4, Pymultinest, and Ultranest.
-        """
-        if self.use_temporary_directory:
-            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
-            self.temporary_outputfiles_basename = temporary_outputfiles_basename
-
-            if os.path.exists(self.outputfiles_basename):
-                distutils.dir_util.copy_tree(self.outputfiles_basename, self.temporary_outputfiles_basename)
-            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
-
-            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
-            logger.info("Using temporary file {}".format(temporary_outputfiles_basename))
-        else:
-            check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-            logger.info("Using output file {}".format(self.outputfiles_basename))
-
 
 class MCMCSampler(Sampler):
-    nwalkers_equiv_kwargs = ['nwalker', 'nwalkers', 'draws', 'Niter']
-    nburn_equiv_kwargs = ['burn', 'nburn']
+    nwalkers_equiv_kwargs = ["nwalker", "nwalkers", "draws", "Niter"]
+    nburn_equiv_kwargs = ["burn", "nburn"]
 
     def print_nburn_logging_info(self):
-        """ Prints logging info as to how nburn was calculated """
+        """Prints logging info as to how nburn was calculated"""
         if type(self.nburn) in [float, int]:
-            logger.info("Discarding {} steps for burn-in".format(self.nburn))
+            logger.info(f"Discarding {self.nburn} steps for burn-in")
         elif self.result.max_autocorrelation_time is None:
-            logger.info("Autocorrelation time not calculated, discarding {} "
-                        " steps for burn-in".format(self.nburn))
+            logger.info(
+                f"Autocorrelation time not calculated, discarding "
+                f"{self.nburn} steps for burn-in"
+            )
         else:
-            logger.info("Discarding {} steps for burn-in, estimated from "
-                        "autocorr".format(self.nburn))
+            logger.info(
+                f"Discarding {self.nburn} steps for burn-in, estimated from autocorr"
+            )
 
     def calculate_autocorrelation(self, samples, c=3):
-        """ Uses the `emcee.autocorr` module to estimate the autocorrelation
+        """Uses the `emcee.autocorr` module to estimate the autocorrelation
 
         Parameters
         ==========
@@ -657,35 +851,155 @@ class MCMCSampler(Sampler):
             estimate (default: `3`). See `emcee.autocorr.integrated_time`.
         """
         import emcee
+
         try:
-            self.result.max_autocorrelation_time = int(np.max(
-                emcee.autocorr.integrated_time(samples, c=c)))
-            logger.info("Max autocorr time = {}".format(
-                self.result.max_autocorrelation_time))
+            self.result.max_autocorrelation_time = int(
+                np.max(emcee.autocorr.integrated_time(samples, c=c))
+            )
+            logger.info(f"Max autocorr time = {self.result.max_autocorrelation_time}")
         except emcee.autocorr.AutocorrError as e:
             self.result.max_autocorrelation_time = None
-            logger.info("Unable to calculate autocorr time: {}".format(e))
+            logger.info(f"Unable to calculate autocorr time: {e}")
+
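
A self-contained sketch of the `emcee.autocorr` call wrapped above; the chain here is synthetic white noise, so the integrated time should come out near 1 (chain shape follows emcee's `(n_step, n_walker, n_dim)` convention):

```python
# Estimate the integrated autocorrelation time of a synthetic chain.
import emcee
import numpy as np

samples = np.random.normal(size=(5000, 8, 2))
tau = emcee.autocorr.integrated_time(samples, c=3)
print("max autocorr time:", int(np.max(tau)))
```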
+
+class _TemporaryFileSamplerMixin:
+    """
+    A mixin class to handle storing sampler intermediate products in a temporary
+    location. See, e.g., `this SO answer <https://stackoverflow.com/a/547714>`_
+    for a basic background on mixins.
+
+    This class makes sure that any subclasses can seamlessly use the temporary
+    file functionality.
+    """
+
+    short_name = ""
+
+    def __init__(self, temporary_directory, **kwargs):
+        super(_TemporaryFileSamplerMixin, self).__init__(**kwargs)
+        self.use_temporary_directory = temporary_directory
+        self._outputfiles_basename = None
+        self._temporary_outputfiles_basename = None
+
+    def _check_and_load_sampling_time_file(self):
+        if os.path.exists(self.time_file_path):
+            with open(self.time_file_path, "r") as time_file:
+                self.total_sampling_time = float(time_file.readline())
+        else:
+            self.total_sampling_time = 0
+
+    def _calculate_and_save_sampling_time(self):
+        current_time = time.time()
+        new_sampling_time = current_time - self.start_time
+        self.total_sampling_time += new_sampling_time
+
+        with open(self.time_file_path, "w") as time_file:
+            time_file.write(str(self.total_sampling_time))
+
+        self.start_time = current_time
+
+    def _clean_up_run_directory(self):
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
+            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
+
+    @property
+    def outputfiles_basename(self):
+        return self._outputfiles_basename
+
+    @outputfiles_basename.setter
+    def outputfiles_basename(self, outputfiles_basename):
+        if outputfiles_basename is None:
+            outputfiles_basename = f"{self.outdir}/{self.short_name}_{self.label}/"
+        if not outputfiles_basename.endswith("/"):
+            outputfiles_basename += "/"
+        check_directory_exists_and_if_not_mkdir(self.outdir)
+        self._outputfiles_basename = outputfiles_basename
+
+    @property
+    def temporary_outputfiles_basename(self):
+        return self._temporary_outputfiles_basename
+
+    @temporary_outputfiles_basename.setter
+    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
+        if not temporary_outputfiles_basename.endswith("/"):
+            temporary_outputfiles_basename += "/"
+        self._temporary_outputfiles_basename = temporary_outputfiles_basename
+        if os.path.exists(self.outputfiles_basename):
+            shutil.copytree(
+                self.outputfiles_basename, self.temporary_outputfiles_basename
+            )
+
+    def write_current_state(self):
+        self._calculate_and_save_sampling_time()
+        if self.use_temporary_directory:
+            self._move_temporary_directory_to_proper_path()
+
+    def _move_temporary_directory_to_proper_path(self):
+        """
+        Move the temporary directory back to the proper path.
+
+        Anything in the proper path at this point is removed, including links.
+        """
+        self._copy_temporary_directory_contents_to_proper_path()
+        shutil.rmtree(self.temporary_outputfiles_basename)
+
+    def _copy_temporary_directory_contents_to_proper_path(self):
+        """
+        Copy the temporary directory contents back to the proper path
+        without deleting the temporary directory.
+        """
+        logger.info(
+            f"Overwriting {self.outputfiles_basename} with {self.temporary_outputfiles_basename}"
+        )
+        outputfiles_basename_stripped = self.outputfiles_basename.rstrip("/")
+        distutils.dir_util.copy_tree(
+            self.temporary_outputfiles_basename, outputfiles_basename_stripped
+        )
+
+    def _setup_run_directory(self):
+        """
+        If using a temporary directory, the output directory is moved to the
+        temporary directory.
+        Used for Dnest4, Pymultinest, and Ultranest.
+        """
+        check_directory_exists_and_if_not_mkdir(self.outputfiles_basename)
+        if self.use_temporary_directory:
+            temporary_outputfiles_basename = tempfile.TemporaryDirectory().name
+            self.temporary_outputfiles_basename = temporary_outputfiles_basename
+
+            if os.path.exists(self.outputfiles_basename):
+                distutils.dir_util.copy_tree(
+                    self.outputfiles_basename, self.temporary_outputfiles_basename
+                )
+            check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename)
+
+            self.kwargs["outputfiles_basename"] = self.temporary_outputfiles_basename
+            logger.info(f"Using temporary file {temporary_outputfiles_basename}")
+        else:
+            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
+            logger.info(f"Using output file {self.outputfiles_basename}")
+        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
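
A toy sketch (stand-in classes, not bilby's real signatures) of how a mixin like `_TemporaryFileSamplerMixin` composes with a sampler base class through cooperative `__init__` and Python's method resolution order:

```python
# Cooperative-init mixin pattern used by _TemporaryFileSamplerMixin above.
class TemporaryFileMixin:
    def __init__(self, temporary_directory=True, **kwargs):
        super().__init__(**kwargs)  # forward remaining kwargs up the MRO
        self.use_temporary_directory = temporary_directory


class BaseSampler:
    def __init__(self, label="label", **kwargs):
        self.label = label


class DemoNestedSampler(TemporaryFileMixin, BaseSampler):
    """MRO: DemoNestedSampler -> TemporaryFileMixin -> BaseSampler."""


s = DemoNestedSampler(temporary_directory=True, label="demo")
print(s.use_temporary_directory, s.label)
```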
 
 
 class Error(Exception):
-    """ Base class for all exceptions raised by this module """
+    """Base class for all exceptions raised by this module"""
 
 
 class SamplerError(Error):
-    """ Base class for Error related to samplers in this module """
+    """Base class for Error related to samplers in this module"""
 
 
 class ResumeError(Error):
-    """ Class for errors arising from resuming runs """
+    """Class for errors arising from resuming runs"""
 
 
 class SamplerNotInstalledError(SamplerError):
-    """ Base class for Error raised by not installed samplers """
+    """Base class for Error raised by not installed samplers"""
 
 
 class IllegalSamplingSetError(Error):
-    """ Class for illegal sets of sampling parameters """
+    """Class for illegal sets of sampling parameters"""
 
 
 class SamplingMarginalisedParameterError(IllegalSamplingSetError):
-    """ Class for errors that occur when sampling over marginalized parameters """
+    """Class for errors that occur when sampling over marginalized parameters"""
diff --git a/bilby/core/sampler/cpnest.py b/bilby/core/sampler/cpnest.py
index e64365f2e423c9ac4968e8c9964f4d26cd55bc00..b124643759fc742465d612918523161b585bb17c 100644
--- a/bilby/core/sampler/cpnest.py
+++ b/bilby/core/sampler/cpnest.py
@@ -1,18 +1,18 @@
-
 import array
 import copy
+import sys
 
 import numpy as np
 from numpy.lib.recfunctions import structured_to_unstructured
 from pandas import DataFrame
 
-from .base_sampler import NestedSampler
-from .proposal import Sample, JumpProposalCycle
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import NestedSampler, signal_wrapper
+from .proposal import JumpProposalCycle, Sample
 
 
 class Cpnest(NestedSampler):
-    """ bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
+    """bilby wrapper of cpnest (https://github.com/johnveitch/cpnest)
 
     All positional and keyword arguments (i.e., the args and kwargs) passed to
     `run_sampler` will be propagated to `cpnest.CPNest`, see documentation
@@ -39,30 +39,44 @@ class Cpnest(NestedSampler):
         {self.outdir}/cpnest_{self.label}/
 
     """
-    default_kwargs = dict(verbose=3, nthreads=1, nlive=500, maxmcmc=1000,
-                          seed=None, poolsize=100, nhamiltonian=0, resume=True,
-                          output=None, proposals=None, n_periodic_checkpoint=8000)
+
+    default_kwargs = dict(
+        verbose=3,
+        nthreads=1,
+        nlive=500,
+        maxmcmc=1000,
+        seed=None,
+        poolsize=100,
+        nhamiltonian=0,
+        resume=True,
+        output=None,
+        proposals=None,
+        n_periodic_checkpoint=8000,
+    )
+    hard_exit = True
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'nthreads' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "nthreads" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nthreads'] = kwargs.pop(equiv)
+                    kwargs["nthreads"] = kwargs.pop(equiv)
 
-        if 'seed' not in kwargs:
-            logger.warning('No seed provided, cpnest will use 1234.')
+        if "seed" not in kwargs:
+            logger.warning("No seed provided, cpnest will use 1234.")
 
+    @signal_wrapper
     def run_sampler(self):
-        from cpnest import model as cpmodel, CPNest
-        from cpnest.parameter import LivePoint
+        from cpnest import CPNest
+        from cpnest import model as cpmodel
         from cpnest.nest2pos import compute_weights
+        from cpnest.parameter import LivePoint
 
         class Model(cpmodel.Model):
-            """ A wrapper class to pass our log_likelihood into cpnest """
+            """A wrapper class to pass our log_likelihood into cpnest"""
 
             def __init__(self, names, priors):
                 self.names = names
@@ -82,14 +96,16 @@ class Cpnest(NestedSampler):
             def _update_bounds(self):
                 self.bounds = [
                     [self.priors[key].minimum, self.priors[key].maximum]
-                    for key in self.names]
+                    for key in self.names
+                ]
 
             def new_point(self):
                 """Draw a point from the prior"""
                 prior_samples = self.priors.sample()
                 self._update_bounds()
                 point = LivePoint(
-                    self.names, array.array('d', [prior_samples[name] for name in self.names])
+                    self.names,
+                    array.array("d", [prior_samples[name] for name in self.names]),
                 )
                 return point
 
@@ -105,18 +121,14 @@ class Cpnest(NestedSampler):
                     kwarg = remove_kwargs.pop(0)
                 else:
                     raise TypeError("Unable to initialise cpnest sampler")
-                logger.info(
-                    "CPNest init. failed with error {}, please update"
-                    .format(e))
-                logger.info(
-                    "Attempting to rerun with kwarg {} removed".format(kwarg))
+                logger.info(f"CPNest init. failed with error {e}, please update")
+                logger.info(f"Attempting to rerun with kwarg {kwarg} removed")
                 self.kwargs.pop(kwarg)
         try:
             out.run()
-        except SystemExit as e:
-            import sys
-            logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code))
-            sys.exit(self.exit_code)
+        except SystemExit:
+            out.checkpoint()
+            self.write_current_state_and_exit()
 
         if self.plot:
             out.plot()
@@ -125,42 +137,58 @@ class Cpnest(NestedSampler):
         self.result.samples = structured_to_unstructured(
             out.posterior_samples[self.search_parameter_keys]
         )
-        self.result.log_likelihood_evaluations = out.posterior_samples['logL']
-        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=''))
-        self.result.nested_samples.rename(columns=dict(logL='log_likelihood'), inplace=True)
-        _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood),
-                                         np.array(out.NS.state.nlive))
-        self.result.nested_samples['weights'] = np.exp(log_weights)
+        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
+        self.result.nested_samples = DataFrame(out.get_nested_samples(filename=""))
+        self.result.nested_samples.rename(
+            columns=dict(logL="log_likelihood"), inplace=True
+        )
+        _, log_weights = compute_weights(
+            np.array(self.result.nested_samples.log_likelihood),
+            np.array(out.NS.state.nlive),
+        )
+        self.result.nested_samples["weights"] = np.exp(log_weights)
         self.result.log_evidence = out.NS.state.logZ
         self.result.log_evidence_err = np.sqrt(out.NS.state.info / out.NS.state.nlive)
         self.result.information_gain = out.NS.state.info
         return self.result
 
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Overrides the base-class method to ensure that :code:`CPNest`
+        terminates properly, since :code:`CPNest` handles all of the
+        multiprocessing internally.
+        """
+        self._log_interruption(signum=signum)
+        sys.exit(self.exit_code)
+
     def _verify_kwargs_against_default_kwargs(self):
         """
         Set the directory where the output will be written
         and check resume and checkpoint status.
         """
-        if not self.kwargs['output']:
-            self.kwargs['output'] = \
-                '{}/cpnest_{}/'.format(self.outdir, self.label)
-        if self.kwargs['output'].endswith('/') is False:
-            self.kwargs['output'] = '{}/'.format(self.kwargs['output'])
-        check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
-        if self.kwargs['n_periodic_checkpoint'] and not self.kwargs['resume']:
-            self.kwargs['n_periodic_checkpoint'] = None
+        if not self.kwargs["output"]:
+            self.kwargs["output"] = f"{self.outdir}/cpnest_{self.label}/"
+        if self.kwargs["output"].endswith("/") is False:
+            self.kwargs["output"] = f"{self.kwargs['output']}/"
+        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
+        if self.kwargs["n_periodic_checkpoint"] and not self.kwargs["resume"]:
+            self.kwargs["n_periodic_checkpoint"] = None
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
     def _resolve_proposal_functions(self):
         from cpnest.proposal import ProposalCycle
-        if 'proposals' in self.kwargs:
-            if self.kwargs['proposals'] is None:
+
+        if "proposals" in self.kwargs:
+            if self.kwargs["proposals"] is None:
                 return
-            if type(self.kwargs['proposals']) == JumpProposalCycle:
-                self.kwargs['proposals'] = dict(mhs=self.kwargs['proposals'], hmc=self.kwargs['proposals'])
-            for key, proposal in self.kwargs['proposals'].items():
+            if type(self.kwargs["proposals"]) == JumpProposalCycle:
+                self.kwargs["proposals"] = dict(
+                    mhs=self.kwargs["proposals"], hmc=self.kwargs["proposals"]
+                )
+            for key, proposal in self.kwargs["proposals"].items():
                 if isinstance(proposal, JumpProposalCycle):
-                    self.kwargs['proposals'][key] = cpnest_proposal_cycle_factory(proposal)
+                    self.kwargs["proposals"][key] = cpnest_proposal_cycle_factory(
+                        proposal
+                    )
                 elif isinstance(proposal, ProposalCycle):
                     pass
                 else:
@@ -171,7 +199,6 @@ def cpnest_proposal_factory(jump_proposal):
     import cpnest.proposal
 
     class CPNestEnsembleProposal(cpnest.proposal.EnsembleProposal):
-
         def __init__(self, jp):
             self.jump_proposal = jp
             self.ensemble = None
@@ -181,8 +208,8 @@ def cpnest_proposal_factory(jump_proposal):
 
         def get_sample(self, cpnest_sample, **kwargs):
             sample = Sample.from_cpnest_live_point(cpnest_sample)
-            self.ensemble = kwargs.get('coordinates', self.ensemble)
-            sample = self.jump_proposal(sample=sample, sampler_name='cpnest', **kwargs)
+            self.ensemble = kwargs.get("coordinates", self.ensemble)
+            sample = self.jump_proposal(sample=sample, sampler_name="cpnest", **kwargs)
             self.log_J = self.jump_proposal.log_j
             return self._update_cpnest_sample(cpnest_sample, sample)
 
@@ -203,11 +230,15 @@ def cpnest_proposal_cycle_factory(jump_proposals):
         def __init__(self):
             self.jump_proposals = copy.deepcopy(jump_proposals)
             for i, prop in enumerate(self.jump_proposals.proposal_functions):
-                self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(prop)
+                self.jump_proposals.proposal_functions[i] = cpnest_proposal_factory(
+                    prop
+                )
             self.jump_proposals.update_cycle()
-            super(CPNestProposalCycle, self).__init__(proposals=self.jump_proposals.proposal_functions,
-                                                      weights=self.jump_proposals.weights,
-                                                      cyclelength=self.jump_proposals.cycle_length)
+            super(CPNestProposalCycle, self).__init__(
+                proposals=self.jump_proposals.proposal_functions,
+                weights=self.jump_proposals.weights,
+                cyclelength=self.jump_proposals.cycle_length,
+            )
 
         def get_sample(self, old, **kwargs):
             return self.jump_proposals(sample=old, coordinates=self.ensemble, **kwargs)
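The interrupt handling introduced for CPNest above — catch the `SystemExit` that CPNest's own machinery raises, checkpoint, then leave through the sampler's (now signal-wrapped) exit path — reduces to the following sketch. Names here are illustrative of the pattern, not bilby's API:

```python
import sys


class CheckpointedRun:
    """Sketch: the external sampler raises SystemExit on an interrupt;
    we persist its state and re-exit with our own exit code so batch
    schedulers can detect a checkpoint-and-resume rather than a crash."""

    exit_code = 130  # hypothetical value; bilby samplers carry their own

    def run(self, external):
        try:
            external.run()
        except SystemExit:
            external.checkpoint()  # the external sampler saves itself
            self.write_current_state_and_exit()

    def write_current_state_and_exit(self, signum=None, frame=None):
        # no pool teardown here: the external sampler owns its processes
        sys.exit(self.exit_code)
```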
diff --git a/bilby/core/sampler/dnest4.py b/bilby/core/sampler/dnest4.py
index ef80c13e933e4dfd6fcaa5e4c3ea8f113b15928e..7d5b97092919e0951de2329de8633e4a6bc7fc2b 100644
--- a/bilby/core/sampler/dnest4.py
+++ b/bilby/core/sampler/dnest4.py
@@ -1,21 +1,17 @@
-import os
-import shutil
-import distutils.dir_util
-import signal
-import time
 import datetime
-import sys
+import time
 
 import numpy as np
 import pandas as pd
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
-from .base_sampler import NestedSampler
+from ..utils import logger
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
 class _DNest4Model(object):
-
-    def __init__(self, log_likelihood_func, from_prior_func, widths, centers, highs, lows):
+    def __init__(
+        self, log_likelihood_func, from_prior_func, widths, centers, highs, lows
+    ):
         """Initialize the DNest4 model.
         Args:
             log_likelihood_func: function
@@ -48,7 +44,7 @@ class _DNest4Model(object):
         """The perturb function to perform Monte Carlo trial moves."""
         idx = np.random.randint(self._n_dim)
 
-        coords[idx] += (self._widths[idx] * (np.random.uniform(size=1) - 0.5))
+        coords[idx] += self._widths[idx] * (np.random.uniform(size=1) - 0.5)
         cw = self._widths[idx]
         cc = self._centers[idx]
 
@@ -59,11 +55,13 @@ class _DNest4Model(object):
     @staticmethod
     def wrap(x, minimum, maximum):
         if maximum <= minimum:
-            raise ValueError("maximum {} <= minimum {}, when trying to wrap coordinates".format(maximum, minimum))
+            raise ValueError(
+                f"maximum {maximum} <= minimum {minimum}, when trying to wrap coordinates"
+            )
         return (x - minimum) % (maximum - minimum) + minimum
 
 
-class DNest4(NestedSampler):
+class DNest4(_TemporaryFileSamplerMixin, NestedSampler):
 
     """
     Bilby wrapper of DNest4
@@ -100,35 +98,58 @@ class DNest4(NestedSampler):
         If True, prints information during run
     """
 
-    default_kwargs = dict(max_num_levels=20, num_steps=500,
-                          new_level_interval=10000, num_per_step=10000,
-                          thread_steps=1, num_particles=1000, lam=10.0,
-                          beta=100, seed=None, verbose=True, outputfiles_basename=None,
-                          backend='memory')
-
-    def __init__(self, likelihood, priors, outdir="outdir", label="label", use_ratio=False, plot=False,
-                 exit_code=77, skip_import_verification=False, temporary_directory=True, **kwargs):
+    default_kwargs = dict(
+        max_num_levels=20,
+        num_steps=500,
+        new_level_interval=10000,
+        num_per_step=10000,
+        thread_steps=1,
+        num_particles=1000,
+        lam=10.0,
+        beta=100,
+        seed=None,
+        verbose=True,
+        outputfiles_basename=None,
+        backend="memory",
+    )
+    short_name = "dn4"
+    hard_exit = True
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        exit_code=77,
+        skip_import_verification=False,
+        temporary_directory=True,
+        **kwargs,
+    ):
         super(DNest4, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-            use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
-            exit_code=exit_code, **kwargs)
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            temporary_directory=temporary_directory,
+            exit_code=exit_code,
+            **kwargs,
+        )
 
         self.num_particles = self.kwargs["num_particles"]
         self.max_num_levels = self.kwargs["max_num_levels"]
         self._verbose = self.kwargs["verbose"]
         self._backend = self.kwargs["backend"]
-        self.use_temporary_directory = temporary_directory
 
         self.start_time = np.nan
         self.sampler = None
         self._information = np.nan
         self._last_live_sample_info = np.nan
-        self._outputfiles_basename = None
-        self._temporary_outputfiles_basename = None
-
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
 
         # Get the estimates of the prior distributions' widths and centers.
         widths = []
@@ -155,13 +176,22 @@ class DNest4(NestedSampler):
         self._highs = np.array(highs)
         self._lows = np.array(lows)
 
-        self._dnest4_model = _DNest4Model(self.log_likelihood, self.get_random_draw_from_prior, self._widths,
-                                          self._centers, self._highs, self._lows)
+        self._dnest4_model = _DNest4Model(
+            self.log_likelihood,
+            self.get_random_draw_from_prior,
+            self._widths,
+            self._centers,
+            self._highs,
+            self._lows,
+        )
 
     def _set_backend(self):
         import dnest4
-        if self._backend == 'csv':
-            return dnest4.backends.CSVBackend("{}/dnest4{}/".format(self.outdir, self.label), sep=" ")
+
+        if self._backend == "csv":
+            return dnest4.backends.CSVBackend(
+                f"{self.outdir}/dnest4{self.label}/", sep=" "
+            )
         else:
             return dnest4.backends.MemoryBackend()
 
@@ -169,6 +199,7 @@ class DNest4(NestedSampler):
         dnest4_keys = ["num_steps", "new_level_interval", "lam", "beta", "seed"]
         self.dnest4_kwargs = {key: self.kwargs[key] for key in dnest4_keys}
 
+    @signal_wrapper
     def run_sampler(self):
         import dnest4
 
@@ -181,31 +212,37 @@ class DNest4(NestedSampler):
         self.start_time = time.time()
 
         self.sampler = dnest4.DNest4Sampler(self._dnest4_model, backend=backend)
-        out = self.sampler.sample(self.max_num_levels,
-                                  num_particles=self.num_particles,
-                                  **self.dnest4_kwargs)
+        out = self.sampler.sample(
+            self.max_num_levels, num_particles=self.num_particles, **self.dnest4_kwargs
+        )
 
         for i, sample in enumerate(out):
             if self._verbose and ((i + 1) % 100 == 0):
                 stats = self.sampler.postprocess()
-                logger.info("Iteration: {0} log(Z): {1}".format(i + 1, stats['log_Z']))
+                logger.info(f"Iteration: {i + 1} log(Z): {stats['log_Z']}")
 
         self._calculate_and_save_sampling_time()
         self._clean_up_run_directory()
 
         stats = self.sampler.postprocess(resample=1)
-        self.result.log_evidence = stats['log_Z']
-        self._information = stats['H']
+        self.result.log_evidence = stats["log_Z"]
+        self._information = stats["H"]
         self.result.log_evidence_err = np.sqrt(self._information / self.num_particles)
 
-        if self._backend == 'memory':
-            self._last_live_sample_info = pd.DataFrame(self.sampler.backend.sample_info[-1])
-            self.result.log_likelihood_evaluations = self._last_live_sample_info['log_likelihood']
+        if self._backend == "memory":
+            self._last_live_sample_info = pd.DataFrame(
+                self.sampler.backend.sample_info[-1]
+            )
+            self.result.log_likelihood_evaluations = self._last_live_sample_info[
+                "log_likelihood"
+            ]
             self.result.samples = np.array(self.sampler.backend.posterior_samples)
         else:
-            sample_info_path = './' + self.kwargs["outputfiles_basename"] + '/sample_info.txt'
-            sample_info = np.genfromtxt(sample_info_path, comments='#', names=True)
-            self.result.log_likelihood_evaluations = sample_info['log_likelihood']
+            sample_info_path = (
+                "./" + self.kwargs["outputfiles_basename"] + "/sample_info.txt"
+            )
+            sample_info = np.genfromtxt(sample_info_path, comments="#", names=True)
+            self.result.log_likelihood_evaluations = sample_info["log_likelihood"]
             self.result.samples = np.array(self.sampler.backend.posterior_samples)
 
         self.result.sampler_output = out
@@ -217,100 +254,11 @@ class DNest4(NestedSampler):
         return self.result
 
     def _translate_kwargs(self, kwargs):
-        if 'num_steps' not in kwargs:
+        if "num_steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['num_steps'] = kwargs.pop(equiv)
+                    kwargs["num_steps"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         self.outputfiles_basename = self.kwargs.pop("outputfiles_basename", None)
         super(DNest4, self)._verify_kwargs_against_default_kwargs()
-
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + '/sampling_time.dat'
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, 'r') as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-
-        with open(self.time_file_path, 'w') as time_file:
-            time_file.write(str(self.total_sampling_time))
-
-        self.start_time = current_time
-
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = "{}/dnest4{}/".format(self.outdir, self.label)
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        sys.exit(self.exit_code)
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
-
-    def _copy_temporary_directory_contents_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path.
-        Do not delete the temporary directory.
-        """
-        logger.info(
-            "Overwriting {} with {}".format(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-        )
-        if self.outputfiles_basename.endswith('/'):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(self.temporary_outputfiles_basename, outputfiles_basename_stripped)
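For orientation, the behaviour the deleted methods delegate to `_TemporaryFileSamplerMixin` can be reconstructed roughly as below (a minimal sketch based on the removed code, not the mixin's actual implementation):

```python
import os
import shutil
import time


class TemporaryFileSamplerSketch:
    """Sketch: accumulate sampling time in a file next to the output and
    mirror results through a scratch directory that is moved back on exit."""

    def _check_and_load_sampling_time_file(self):
        # resume the cumulative timer if a previous run left one behind
        self.time_file_path = os.path.join(
            self.outputfiles_basename, "sampling_time.dat"
        )
        if os.path.exists(self.time_file_path):
            with open(self.time_file_path) as f:
                self.total_sampling_time = float(f.readline())
        else:
            self.total_sampling_time = 0

    def _calculate_and_save_sampling_time(self):
        now = time.time()
        self.total_sampling_time += now - self.start_time
        with open(self.time_file_path, "w") as f:
            f.write(str(self.total_sampling_time))
        self.start_time = now

    def _move_temporary_directory_to_proper_path(self):
        # copy scratch results over the real path, then drop the scratch
        # (dirs_exist_ok needs Python >= 3.8)
        shutil.copytree(
            self.temporary_outputfiles_basename,
            self.outputfiles_basename,
            dirs_exist_ok=True,
        )
        shutil.rmtree(self.temporary_outputfiles_basename)
```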
diff --git a/bilby/core/sampler/dynamic_dynesty.py b/bilby/core/sampler/dynamic_dynesty.py
index 8bb6d647aad013c5be957cfde6311f10f9feda82..ef28f22ddbb2ce099d481c3adebaae1ef1d3b0cd 100644
--- a/bilby/core/sampler/dynamic_dynesty.py
+++ b/bilby/core/sampler/dynamic_dynesty.py
@@ -1,12 +1,10 @@
-
-import os
-import signal
+import datetime
 
 import numpy as np
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import Sampler
-from .dynesty import Dynesty
+from ..utils import logger
+from .base_sampler import Sampler, signal_wrapper
+from .dynesty import Dynesty, _log_likelihood_wrapper, _prior_transform_wrapper
 
 
 class DynamicDynesty(Dynesty):
@@ -62,33 +60,77 @@ class DynamicDynesty(Dynesty):
     resume: bool
         If true, resume run from checkpoint (if available)
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          verbose=True,
-                          check_point_delta_t=600,
-                          first_update=None,
-                          npdim=None, rstate=None, queue_size=None, pool=None,
-                          use_pool=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
-                          facc=0.5, slices=5,
-                          walks=None, update_interval=0.6,
-                          nlive_init=500, maxiter_init=None, maxcall_init=None,
-                          dlogz_init=0.01, logl_max_init=np.inf, nlive_batch=500,
-                          wt_function=None, wt_kwargs=None, maxiter_batch=None,
-                          maxcall_batch=None, maxiter=None, maxcall=None,
-                          maxbatch=None, stop_function=None, stop_kwargs=None,
-                          use_stop=True, save_bounds=True,
-                          print_progress=True, print_func=None, live_points=None,
-                          )
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, check_point=True, n_check_point=None, check_point_delta_t=600,
-                 resume=True, **kwargs):
-        super(DynamicDynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                             outdir=outdir, label=label, use_ratio=use_ratio,
-                                             plot=plot, skip_import_verification=skip_import_verification,
-                                             **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        verbose=True,
+        check_point_delta_t=600,
+        first_update=None,
+        npdim=None,
+        rstate=None,
+        queue_size=None,
+        pool=None,
+        use_pool=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=None,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=2.0,
+        facc=0.5,
+        slices=5,
+        walks=None,
+        update_interval=0.6,
+        nlive_init=500,
+        maxiter_init=None,
+        maxcall_init=None,
+        dlogz_init=0.01,
+        logl_max_init=np.inf,
+        nlive_batch=500,
+        wt_function=None,
+        wt_kwargs=None,
+        maxiter_batch=None,
+        maxcall_batch=None,
+        maxiter=None,
+        maxcall=None,
+        maxbatch=None,
+        stop_function=None,
+        stop_kwargs=None,
+        use_stop=True,
+        save_bounds=True,
+        print_progress=True,
+        print_func=None,
+        live_points=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        **kwargs,
+    ):
+        super(DynamicDynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.resume = resume
@@ -97,39 +139,59 @@ class DynamicDynesty(Dynesty):
             # check_point is set to False.
             if np.isnan(self._log_likelihood_eval_time):
                 self.check_point = False
-            n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
-            n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
+            n_check_point_raw = check_point_delta_t / self._log_likelihood_eval_time
+            n_check_point_rnd = int(float(f"{n_check_point_raw:1.0g}"))
             self.n_check_point = n_check_point_rnd
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
-
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
 
     @property
     def external_sampler_name(self):
-        return 'dynesty'
+        return "dynesty"
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['nlive_init', 'maxiter_init', 'maxcall_init', 'dlogz_init',
-                'logl_max_init', 'nlive_batch', 'wt_function', 'wt_kwargs',
-                'maxiter_batch', 'maxcall_batch', 'maxiter', 'maxcall',
-                'maxbatch', 'stop_function', 'stop_kwargs', 'use_stop',
-                'save_bounds', 'print_progress', 'print_func', 'live_points']
+        keys = [
+            "nlive_init",
+            "maxiter_init",
+            "maxcall_init",
+            "dlogz_init",
+            "logl_max_init",
+            "nlive_batch",
+            "wt_function",
+            "wt_kwargs",
+            "maxiter_batch",
+            "maxcall_batch",
+            "maxiter",
+            "maxcall",
+            "maxbatch",
+            "stop_function",
+            "stop_kwargs",
+            "use_stop",
+            "save_bounds",
+            "print_progress",
+            "print_func",
+            "live_points",
+        ]
         return {key: self.kwargs[key] for key in keys}
 
+    @signal_wrapper
     def run_sampler(self):
         import dynesty
+
+        self._setup_pool()
         self.sampler = dynesty.DynamicNestedSampler(
-            loglikelihood=self.log_likelihood,
-            prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.sampler_init_kwargs)
+            loglikelihood=_log_likelihood_wrapper,
+            prior_transform=_prior_transform_wrapper,
+            ndim=self.ndim,
+            **self.sampler_init_kwargs,
+        )
 
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._close_pool()
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
@@ -147,13 +209,14 @@ class DynamicDynesty(Dynesty):
         if self.resume:
             resume = self.read_saved_state(continuing=True)
             if resume:
-                logger.info('Resuming from previous run.')
+                logger.info("Resuming from previous run.")
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
+        sampler_kwargs["maxcall"] = self.n_check_point
+        self.start_time = datetime.datetime.now()
         while True:
-            sampler_kwargs['maxcall'] += self.n_check_point
+            sampler_kwargs["maxcall"] += self.n_check_point
             self.sampler.run_nested(**sampler_kwargs)
             if self.sampler.ncall == old_ncall:
                 break
@@ -164,27 +227,8 @@ class DynamicDynesty(Dynesty):
         self._remove_checkpoint()
         return self.sampler.results
 
-    def write_current_state(self):
-        """
-        """
-        import dill
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        with open(self.resume_file, 'wb') as file:
-            dill.dump(self, file)
-
-    def read_saved_state(self, continuing=False):
-        """
-        """
-        import dill
-
-        logger.debug("Reading resume file {}".format(self.resume_file))
-        if os.path.isfile(self.resume_file):
-            with open(self.resume_file, 'rb') as file:
-                self = dill.load(file)
-        else:
-            logger.debug(
-                "Failed to read resume file {}".format(self.resume_file))
-            return False
+    def write_current_state_and_exit(self, signum=None, frame=None):
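+        """Bypass Dynesty's override (it assumes a print_method kwarg not set here)."""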
+        Sampler.write_current_state_and_exit(self=self, signum=signum, frame=frame)
 
     def _verify_kwargs_against_default_kwargs(self):
         Sampler._verify_kwargs_against_default_kwargs(self)
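The `signal_wrapper` decorator that replaces the per-sampler `signal.signal(...)` boilerplate throughout this patch plausibly looks like the sketch below, inferred from the handler registrations it removes; this is a reconstruction, not the actual `base_sampler` code:

```python
import functools
import signal


def signal_wrapper(method):
    """Sketch: install checkpoint-and-exit handlers around run_sampler,
    mirroring the try/except signal.signal blocks deleted above."""

    @functools.wraps(method)
    def wrapped(self, *args, **kwargs):
        try:
            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
        except (AttributeError, ValueError):
            # e.g. Windows (no SIGALRM) or a non-main thread
            pass
        return method(self, *args, **kwargs)

    return wrapped
```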
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index d9f1ae94bc4995c0b1f0bd88649ff3850e571085..ab0af61be6dc55ffac996931e4ed48a00cc348f1 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -1,62 +1,51 @@
 import datetime
 import os
 import sys
-import signal
 import time
 import warnings
 
 import numpy as np
 from pandas import DataFrame
 
+from ..result import rejection_sample
 from ..utils import (
-    logger,
     check_directory_exists_and_if_not_mkdir,
+    latex_plot_format,
+    logger,
     reflect,
     safe_file_dump,
-    latex_plot_format,
 )
-from .base_sampler import Sampler, NestedSampler
-from ..result import rejection_sample
-
-_likelihood = None
-_priors = None
-_search_parameter_keys = None
-_use_ratio = False
-
-
-def _initialize_global_variables(
-        likelihood, priors, search_parameter_keys, use_ratio
-):
-    """
-    Store a global copy of the likelihood, priors, and search keys for
-    multiprocessing.
-    """
-    global _likelihood
-    global _priors
-    global _search_parameter_keys
-    global _use_ratio
-    _likelihood = likelihood
-    _priors = priors
-    _search_parameter_keys = search_parameter_keys
-    _use_ratio = use_ratio
+from .base_sampler import NestedSampler, Sampler, signal_wrapper
 
 
 def _prior_transform_wrapper(theta):
     """Wrapper to the prior transformation. Needed for multiprocessing."""
-    return _priors.rescale(_search_parameter_keys, theta)
+    from .base_sampler import _sampling_convenience_dump
+
+    return _sampling_convenience_dump.priors.rescale(
+        _sampling_convenience_dump.search_parameter_keys, theta
+    )
 
 
 def _log_likelihood_wrapper(theta):
     """Wrapper to the log likelihood. Needed for multiprocessing."""
-    if _priors.evaluate_constraints({
-        key: theta[ii] for ii, key in enumerate(_search_parameter_keys)
-    }):
-        params = {key: t for key, t in zip(_search_parameter_keys, theta)}
-        _likelihood.parameters.update(params)
-        if _use_ratio:
-            return _likelihood.log_likelihood_ratio()
+    from .base_sampler import _sampling_convenience_dump
+
+    if _sampling_convenience_dump.priors.evaluate_constraints(
+        {
+            key: theta[ii]
+            for ii, key in enumerate(_sampling_convenience_dump.search_parameter_keys)
+        }
+    ):
+        params = {
+            key: t
+            for key, t in zip(_sampling_convenience_dump.search_parameter_keys, theta)
+        }
+        _sampling_convenience_dump.likelihood.parameters.update(params)
+        if _sampling_convenience_dump.use_ratio:
+            return _sampling_convenience_dump.likelihood.log_likelihood_ratio()
         else:
-            return _likelihood.log_likelihood()
+            return _sampling_convenience_dump.likelihood.log_likelihood()
     else:
         return np.nan_to_num(-np.inf)
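Both wrappers are module-level functions that read process-global state, so the callables handed to dynesty stay picklable and the likelihood is shipped to each worker exactly once. The pattern, sketched with a plain namespace standing in for the real `_sampling_convenience_dump`:

```python
import multiprocessing
from types import SimpleNamespace

# Stand-in for the real convenience dump: one mutable namespace per process.
_dump = SimpleNamespace(
    likelihood=None, priors=None, search_parameter_keys=None, use_ratio=False
)


def _init_worker(likelihood, priors, keys, use_ratio):
    # runs once in each worker at pool start-up
    _dump.likelihood = likelihood
    _dump.priors = priors
    _dump.search_parameter_keys = keys
    _dump.use_ratio = use_ratio


def _make_pool(likelihood, priors, keys, use_ratio, processes):
    # afterwards, module-level wrapper functions can be submitted to the
    # pool without re-pickling the likelihood on every call
    return multiprocessing.Pool(
        processes=processes,
        initializer=_init_worker,
        initargs=(likelihood, priors, keys, use_ratio),
    )
```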
 
@@ -130,32 +119,77 @@ class Dynesty(NestedSampler):
           e.g., 'interval-10' prints every ten seconds, this does not print every iteration
         - else: print to `stdout` at every iteration
     """
-    default_kwargs = dict(bound='multi', sample='rwalk',
-                          periodic=None, reflective=None,
-                          check_point_delta_t=1800, nlive=1000,
-                          first_update=None, walks=100,
-                          npdim=None, rstate=None, queue_size=1, pool=None,
-                          use_pool=None, live_points=None,
-                          logl_args=None, logl_kwargs=None,
-                          ptform_args=None, ptform_kwargs=None,
-                          enlarge=1.5, bootstrap=None, vol_dec=0.5, vol_check=8.0,
-                          facc=0.2, slices=5,
-                          update_interval=None, print_func=None,
-                          dlogz=0.1, maxiter=None, maxcall=None,
-                          logl_max=np.inf, add_live=True, print_progress=True,
-                          save_bounds=False, n_effective=None,
-                          maxmcmc=5000, nact=5, print_method="tqdm")
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 check_point=True, check_point_plot=True, n_check_point=None,
-                 check_point_delta_t=600, resume=True, nestcheck=False, exit_code=130, **kwargs):
-
-        super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
-                                      outdir=outdir, label=label, use_ratio=use_ratio,
-                                      plot=plot, skip_import_verification=skip_import_verification,
-                                      exit_code=exit_code,
-                                      **kwargs)
+
+    default_kwargs = dict(
+        bound="multi",
+        sample="rwalk",
+        print_progress=True,
+        periodic=None,
+        reflective=None,
+        check_point_delta_t=1800,
+        nlive=1000,
+        first_update=None,
+        walks=100,
+        npdim=None,
+        rstate=None,
+        queue_size=1,
+        pool=None,
+        use_pool=None,
+        live_points=None,
+        logl_args=None,
+        logl_kwargs=None,
+        ptform_args=None,
+        ptform_kwargs=None,
+        enlarge=1.5,
+        bootstrap=None,
+        vol_dec=0.5,
+        vol_check=8.0,
+        facc=0.2,
+        slices=5,
+        update_interval=None,
+        print_func=None,
+        dlogz=0.1,
+        maxiter=None,
+        maxcall=None,
+        logl_max=np.inf,
+        add_live=True,
+        save_bounds=False,
+        n_effective=None,
+        maxmcmc=5000,
+        nact=5,
+        print_method="tqdm",
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        check_point=True,
+        check_point_plot=True,
+        n_check_point=None,
+        check_point_delta_t=600,
+        resume=True,
+        nestcheck=False,
+        exit_code=130,
+        **kwargs,
+    ):
+        self._translate_kwargs(kwargs)
+        super(Dynesty, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            exit_code=exit_code,
+            **kwargs,
+        )
         self.n_check_point = n_check_point
         self.check_point = check_point
         self.check_point_plot = check_point_plot
@@ -169,77 +203,79 @@ class Dynesty(NestedSampler):
         if self.n_check_point is None:
             self.n_check_point = 1000
         self.check_point_delta_t = check_point_delta_t
-        logger.info("Checkpoint every check_point_delta_t = {}s"
-                    .format(check_point_delta_t))
+        logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
 
-        self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
+        self.resume_file = f"{self.outdir}/{self.label}_resume.pickle"
         self.sampling_time = datetime.timedelta()
 
-        try:
-            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-        except AttributeError:
-            logger.debug(
-                "Setting signal attributes unavailable on this system. "
-                "This is likely the case if you are running on a Windows machine"
-                " and is no further concern.")
-
     def __getstate__(self):
-        """ For pickle: remove external_sampler, which can be an unpicklable "module" """
+        """For pickle: remove external_sampler, which can be an unpicklable "module" """
         state = self.__dict__.copy()
         if "external_sampler" in state:
-            del state['external_sampler']
+            del state["external_sampler"]
         return state
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
-                'maxcall', 'logl_max', 'add_live', 'save_bounds',
-                'n_effective']
+        keys = [
+            "dlogz",
+            "print_progress",
+            "print_func",
+            "maxiter",
+            "maxcall",
+            "logl_max",
+            "add_live",
+            "save_bounds",
+            "n_effective",
+        ]
         return {key: self.kwargs[key] for key in keys}
 
     @property
     def sampler_init_kwargs(self):
-        return {key: value
-                for key, value in self.kwargs.items()
-                if key not in self.sampler_function_kwargs}
+        return {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'print_progress' not in kwargs:
-            if 'verbose' in kwargs:
-                kwargs['print_progress'] = kwargs.pop('verbose')
-        if 'walks' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "print_progress" not in kwargs:
+            if "verbose" in kwargs:
+                kwargs["print_progress"] = kwargs.pop("verbose")
+        if "walks" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['walks'] = kwargs.pop(equiv)
+                    kwargs["walks"] = kwargs.pop(equiv)
         if "queue_size" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['queue_size'] = kwargs.pop(equiv)
+                    kwargs["queue_size"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         from tqdm.auto import tqdm
-        if not self.kwargs['walks']:
-            self.kwargs['walks'] = 100
-        if not self.kwargs['update_interval']:
-            self.kwargs['update_interval'] = int(0.6 * self.kwargs['nlive'])
-        if self.kwargs['print_func'] is None:
-            self.kwargs['print_func'] = self._print_func
+
+        if not self.kwargs["walks"]:
+            self.kwargs["walks"] = 100
+        if not self.kwargs["update_interval"]:
+            self.kwargs["update_interval"] = int(0.6 * self.kwargs["nlive"])
+        if self.kwargs["print_func"] is None:
+            self.kwargs["print_func"] = self._print_func
             print_method = self.kwargs["print_method"]
             if print_method == "tqdm" and self.kwargs["print_progress"]:
                 self.pbar = tqdm(file=sys.stdout)
             elif "interval" in print_method:
                 self._last_print_time = datetime.datetime.now()
-                self._print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
+                self._print_interval = datetime.timedelta(
+                    seconds=float(print_method.split("-")[1])
+                )
         Sampler._verify_kwargs_against_default_kwargs(self)
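The `print_method` plumbing set up here parses an update cadence out of the kwarg string; a self-contained sketch of that parsing (the 'interval-10' value comes from the class docstring's example):

```python
import datetime

# "tqdm" is the default (progress bar); "interval-<seconds>" rate-limits
# plain-text status lines; anything else prints every iteration.
print_method = "interval-10"
if "interval" in print_method:
    # the number after the dash is parsed as float seconds between prints
    print_interval = datetime.timedelta(seconds=float(print_method.split("-")[1]))
    print(print_interval)  # -> 0:00:10
```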
 
     def _print_func(self, results, niter, ncall=None, dlogz=None, *args, **kwargs):
-        """ Replacing status update for dynesty.result.print_func """
+        """Replacing status update for dynesty.result.print_func"""
         if "interval" in self.kwargs["print_method"]:
             _time = datetime.datetime.now()
             if _time - self._last_print_time < self._print_interval:
@@ -251,17 +287,31 @@ class Dynesty(NestedSampler):
                 total_time = self.sampling_time + _time - self.start_time
 
                 # Remove fractional seconds
-                total_time_str = str(total_time).split('.')[0]
+                total_time_str = str(total_time).split(".")[0]
 
         # Extract results at the current iteration.
-        (worst, ustar, vstar, loglstar, logvol, logwt,
-         logz, logzvar, h, nc, worst_it, boundidx, bounditer,
-         eff, delta_logz) = results
+        (
+            worst,
+            ustar,
+            vstar,
+            loglstar,
+            logvol,
+            logwt,
+            logz,
+            logzvar,
+            h,
+            nc,
+            worst_it,
+            boundidx,
+            bounditer,
+            eff,
+            delta_logz,
+        ) = results
 
         # Adjusting outputs for printing.
         if delta_logz > 1e6:
             delta_logz = np.inf
-        if 0. <= logzvar <= 1e6:
+        if 0.0 <= logzvar <= 1e6:
             logzerr = np.sqrt(logzvar)
         else:
             logzerr = np.nan
@@ -271,38 +321,38 @@ class Dynesty(NestedSampler):
             loglstar = -np.inf
 
         if self.use_ratio:
-            key = 'logz-ratio'
+            key = "logz-ratio"
         else:
-            key = 'logz'
+            key = "logz"
 
         # Constructing output.
         string = list()
-        string.append("bound:{:d}".format(bounditer))
-        string.append("nc:{:3d}".format(nc))
-        string.append("ncall:{:.1e}".format(ncall))
-        string.append("eff:{:0.1f}%".format(eff))
-        string.append("{}={:0.2f}+/-{:0.2f}".format(key, logz, logzerr))
-        string.append("dlogz:{:0.3f}>{:0.2g}".format(delta_logz, dlogz))
+        string.append(f"bound:{bounditer:d}")
+        string.append(f"nc:{nc:3d}")
+        string.append(f"ncall:{ncall:.1e}")
+        string.append(f"eff:{eff:0.1f}%")
+        string.append(f"{key}={logz:0.2f}+/-{logzerr:0.2f}")
+        string.append(f"dlogz:{delta_logz:0.3f}>{dlogz:0.2g}")
 
         if self.kwargs["print_method"] == "tqdm":
             self.pbar.set_postfix_str(" ".join(string), refresh=False)
             self.pbar.update(niter - self.pbar.n)
         elif "interval" in self.kwargs["print_method"]:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
+            print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True)
         else:
             formatted = " ".join([total_time_str] + string)
-            print("{}it [{}]".format(niter, formatted), file=sys.stdout, flush=True)
+            print(f"{niter}it [{formatted}]", file=sys.stdout, flush=True)
 
     def _apply_dynesty_boundaries(self):
         self._periodic = list()
         self._reflective = list()
         for ii, key in enumerate(self.search_parameter_keys):
-            if self.priors[key].boundary == 'periodic':
-                logger.debug("Setting periodic boundary for {}".format(key))
+            if self.priors[key].boundary == "periodic":
+                logger.debug(f"Setting periodic boundary for {key}")
                 self._periodic.append(ii)
-            elif self.priors[key].boundary == 'reflective':
-                logger.debug("Setting reflective boundary for {}".format(key))
+            elif self.priors[key].boundary == "reflective":
+                logger.debug(f"Setting reflective boundary for {key}")
                 self._reflective.append(ii)
 
         # The periodic kwargs passed into dynesty allows the parameters to
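How a prior's `boundary` attribute ends up in those dynesty kwargs can be shown end-to-end (a sketch using bilby priors; the index-list construction mirrors the loop above):

```python
import bilby

priors = dict(
    phase=bilby.core.prior.Uniform(0, 6.28, name="phase", boundary="periodic"),
    amp=bilby.core.prior.Uniform(0, 1, name="amp", boundary="reflective"),
)
keys = list(priors.keys())
# dynesty wants parameter *indices*, not names
periodic = [ii for ii, key in enumerate(keys) if priors[key].boundary == "periodic"]
reflective = [ii for ii, key in enumerate(keys) if priors[key].boundary == "reflective"]
print(periodic, reflective)  # -> [0] [1]
```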
@@ -312,61 +362,26 @@ class Dynesty(NestedSampler):
         self.kwargs["reflective"] = self._reflective
 
     def nestcheck_data(self, out_file):
-        import nestcheck.data_processing
         import pickle
+
+        import nestcheck.data_processing
+
         ns_run = nestcheck.data_processing.process_dynesty_run(out_file)
-        nestcheck_result = "{}/{}_nestcheck.pickle".format(self.outdir, self.label)
-        with open(nestcheck_result, 'wb') as file_nest:
+        nestcheck_result = f"{self.outdir}/{self.label}_nestcheck.pickle"
+        with open(nestcheck_result, "wb") as file_nest:
             pickle.dump(ns_run, file_nest)
 
-    def _setup_pool(self):
-        if self.kwargs["pool"] is not None:
-            logger.info("Using user defined pool.")
-            self.pool = self.kwargs["pool"]
-        elif self.kwargs["queue_size"] > 1:
-            logger.info(
-                "Setting up multiproccesing pool with {} processes.".format(
-                    self.kwargs["queue_size"]
-                )
-            )
-            import multiprocessing
-            self.pool = multiprocessing.Pool(
-                processes=self.kwargs["queue_size"],
-                initializer=_initialize_global_variables,
-                initargs=(
-                    self.likelihood,
-                    self.priors,
-                    self._search_parameter_keys,
-                    self.use_ratio
-                )
-            )
-        else:
-            _initialize_global_variables(
-                likelihood=self.likelihood,
-                priors=self.priors,
-                search_parameter_keys=self._search_parameter_keys,
-                use_ratio=self.use_ratio
-            )
-            self.pool = None
-        self.kwargs["pool"] = self.pool
-
-    def _close_pool(self):
-        if getattr(self, "pool", None) is not None:
-            logger.info("Starting to close worker pool.")
-            self.pool.close()
-            self.pool.join()
-            self.pool = None
-            self.kwargs["pool"] = self.pool
-            logger.info("Finished closing worker pool.")
-
+    @signal_wrapper
     def run_sampler(self):
-        import dynesty
         import dill
-        logger.info("Using dynesty version {}".format(dynesty.__version__))
+        import dynesty
+
+        logger.info(f"Using dynesty version {dynesty.__version__}")
 
         if self.kwargs.get("sample", "rwalk") == "rwalk":
             logger.info(
-                "Using the bilby-implemented rwalk sample method with ACT estimated walks")
+                "Using the bilby-implemented rwalk sample method with ACT estimated walks"
+            )
             dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
             dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
             if self.kwargs.get("walks") > self.kwargs.get("maxmcmc"):
@@ -375,12 +390,10 @@ class Dynesty(NestedSampler):
                 raise DynestySetupError("Unable to run with nact < 1")
         elif self.kwargs.get("sample") == "rwalk_dynesty":
             self._kwargs["sample"] = "rwalk"
-            logger.info(
-                "Using the dynesty-implemented rwalk sample method")
+            logger.info("Using the dynesty-implemented rwalk sample method")
         elif self.kwargs.get("sample") == "rstagger_dynesty":
             self._kwargs["sample"] = "rstagger"
-            logger.info(
-                "Using the dynesty-implemented rstagger sample method")
+            logger.info("Using the dynesty-implemented rstagger sample method")
 
         self._setup_pool()
 
@@ -388,22 +401,25 @@ class Dynesty(NestedSampler):
             self.resume = self.read_saved_state(continuing=True)
 
         if self.resume:
-            logger.info('Resume file successfully loaded.')
+            logger.info("Resume file successfully loaded.")
         else:
-            if self.kwargs['live_points'] is None:
-                self.kwargs['live_points'] = (
-                    self.get_initial_points_from_prior(self.kwargs['nlive'])
+            if self.kwargs["live_points"] is None:
+                self.kwargs["live_points"] = self.get_initial_points_from_prior(
+                    self.kwargs["nlive"]
                 )
             self.sampler = dynesty.NestedSampler(
                 loglikelihood=_log_likelihood_wrapper,
                 prior_transform=_prior_transform_wrapper,
-                ndim=self.ndim, **self.sampler_init_kwargs
+                ndim=self.ndim,
+                **self.sampler_init_kwargs,
             )
 
+        self.start_time = datetime.datetime.now()
         if self.check_point:
             out = self._run_external_sampler_with_checkpointing()
         else:
             out = self._run_external_sampler_without_checkpointing()
+        self._update_sampling_time()
 
         self._close_pool()
 
@@ -417,8 +433,8 @@ class Dynesty(NestedSampler):
         if self.nestcheck:
             self.nestcheck_data(out)
 
-        dynesty_result = "{}/{}_dynesty.pickle".format(self.outdir, self.label)
-        with open(dynesty_result, 'wb') as file:
+        dynesty_result = f"{self.outdir}/{self.label}_dynesty.pickle"
+        with open(dynesty_result, "wb") as file:
             dill.dump(out, file)
 
         self._generate_result(out)
@@ -432,21 +448,23 @@ class Dynesty(NestedSampler):
     def _generate_result(self, out):
         import dynesty
         from scipy.special import logsumexp
+
         logwts = out["logwt"]
-        weights = np.exp(logwts - out['logz'][-1])
-        nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        nested_samples['weights'] = weights
-        nested_samples['log_likelihood'] = out.logl
+        weights = np.exp(logwts - out["logz"][-1])
+        nested_samples = DataFrame(out.samples, columns=self.search_parameter_keys)
+        nested_samples["weights"] = weights
+        nested_samples["log_likelihood"] = out.logl
         self.result.samples = dynesty.utils.resample_equal(out.samples, weights)
         self.result.nested_samples = nested_samples
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
+            unsorted_loglikelihoods=out.logl,
+            unsorted_samples=out.samples,
+            sorted_samples=self.result.samples,
+        )
         self.result.log_evidence = out.logz[-1]
         self.result.log_evidence_err = out.logzerr[-1]
         self.result.information_gain = out.information[-1]
-        self.result.num_likelihood_evaluations = getattr(self.sampler, 'ncall', 0)
+        self.result.num_likelihood_evaluations = getattr(self.sampler, "ncall", 0)
 
         logneff = logsumexp(logwts) * 2 - logsumexp(logwts * 2)
         neffsamples = int(np.exp(logneff))
@@ -454,11 +472,16 @@ class Dynesty(NestedSampler):
             nlikelihood=self.result.num_likelihood_evaluations,
             neffsamples=neffsamples,
             sampling_time_s=self.sampling_time.seconds,
-            ncores=self.kwargs.get("queue_size", 1)
+            ncores=self.kwargs.get("queue_size", 1),
         )
 
+    def _update_sampling_time(self):
+        end_time = datetime.datetime.now()
+        self.sampling_time += end_time - self.start_time
+        self.start_time = end_time
+
     def _run_nested_wrapper(self, kwargs):
-        """ Wrapper function to run_nested
+        """Wrapper function to run_nested
 
         This wrapper catches exceptions related to different versions of
         dynesty accepting different arguments.
@@ -469,8 +492,7 @@ class Dynesty(NestedSampler):
             The dictionary of kwargs to pass to run_nested
 
         """
-        logger.debug("Calling run_nested with sampler_function_kwargs {}"
-                     .format(kwargs))
+        logger.debug(f"Calling run_nested with sampler_function_kwargs {kwargs}")
         try:
             self.sampler.run_nested(**kwargs)
         except TypeError:
@@ -487,9 +509,8 @@ class Dynesty(NestedSampler):
 
         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxcall'] = self.n_check_point
-        sampler_kwargs['add_live'] = True
-        self.start_time = datetime.datetime.now()
+        sampler_kwargs["maxcall"] = self.n_check_point
+        sampler_kwargs["add_live"] = True
         while True:
             self._run_nested_wrapper(sampler_kwargs)
             if self.sampler.ncall == old_ncall:
@@ -499,14 +520,16 @@ class Dynesty(NestedSampler):
             if os.path.isfile(self.resume_file):
                 last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
             else:
-                last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
+                last_checkpoint_s = (
+                    datetime.datetime.now() - self.start_time
+                ).total_seconds()
             if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state()
                 self.plot_current_state()
             if self.sampler.added_live:
                 self.sampler._remove_live_points()
 
-        sampler_kwargs['add_live'] = True
+        sampler_kwargs["add_live"] = True
         self._run_nested_wrapper(sampler_kwargs)
         self.write_current_state()
         self.plot_current_state()
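The checkpointing loop above advances dynesty in `maxcall`-sized chunks and writes a resume file whenever the existing one is older than `check_point_delta_t`; its control flow, reduced to a sketch with injected callables:

```python
import os
import time


def run_with_checkpointing(sampler, run_chunk, write_state, resume_file,
                           n_check_point, check_point_delta_t):
    """Sketch of the loop's shape: fixed-size chunks of likelihood calls,
    checkpoints on a wall-clock cadence keyed off the resume file's mtime."""
    maxcall = n_check_point
    while True:
        old_ncall = sampler.ncall
        maxcall += n_check_point
        run_chunk(maxcall=maxcall)      # dynesty stops once maxcall is hit
        if sampler.ncall == old_ncall:  # no new calls: sampling finished
            break
        if os.path.isfile(resume_file):
            age = time.time() - os.path.getmtime(resume_file)
        else:
            age = float("inf")          # never checkpointed: do it now
        if age > check_point_delta_t:
            write_state()
    write_state()                       # final checkpoint before returning
```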
@@ -534,39 +557,41 @@ class Dynesty(NestedSampler):
             Whether the run is continuing or terminating, if True, the loaded
             state is mostly written back to disk.
         """
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
+
         versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         if os.path.isfile(self.resume_file):
-            logger.info("Reading resume file {}".format(self.resume_file))
-            with open(self.resume_file, 'rb') as file:
+            logger.info(f"Reading resume file {self.resume_file}")
+            with open(self.resume_file, "rb") as file:
                 sampler = dill.load(file)
 
                 if not hasattr(sampler, "versions"):
                     logger.warning(
-                        "The resume file {} is corrupted or the version of "
-                        "bilby has changed between runs. This resume file will "
-                        "be ignored."
-                        .format(self.resume_file)
+                        f"The resume file {self.resume_file} is corrupted or "
+                        "the version of bilby has changed between runs. This "
+                        "resume file will be ignored."
                     )
                     return False
                 version_warning = (
                     "The {code} version has changed between runs. "
                     "This may cause unpredictable behaviour and/or failure. "
                     "Old version = {old}, new version = {new}."
-
                 )
                 for code in versions:
                     if not versions[code] == sampler.versions.get(code, None):
-                        logger.warning(version_warning.format(
-                            code=code,
-                            old=sampler.versions.get(code, "None"),
-                            new=versions[code]
-                        ))
+                        logger.warning(
+                            version_warning.format(
+                                code=code,
+                                old=sampler.versions.get(code, "None"),
+                                new=versions[code],
+                            )
+                        )
                 del sampler.versions
                 self.sampler = sampler
-                if self.sampler.added_live and continuing:
+                if getattr(self.sampler, "added_live", False) and continuing:
                     self.sampler._remove_live_points()
                 self.sampler.nqueue = -1
                 self.sampler.rstate = np.random
@@ -579,27 +604,13 @@ class Dynesty(NestedSampler):
                     self.sampler.M = map
             return True
         else:
-            logger.info(
-                "Resume file {} does not exist.".format(self.resume_file))
+            logger.info(f"Resume file {self.resume_file} does not exist.")
             return False
 
     def write_current_state_and_exit(self, signum=None, frame=None):
-        """
-        Make sure that if a pool of jobs is running only the parent tries to
-        checkpoint and exit. Only the parent has a 'pool' attribute.
-        """
-        if self.kwargs["queue_size"] == 1 or getattr(self, "pool", None) is not None:
-            if signum == 14:
-                logger.info(
-                    "Run interrupted by alarm signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            else:
-                logger.info(
-                    "Run interrupted by signal {}: checkpoint and exit on {}"
-                    .format(signum, self.exit_code))
-            self.write_current_state()
-            self._close_pool()
-            os._exit(self.exit_code)
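+        # close the progress bar so the terminal is left clean, then defer
+        # checkpoint-and-exit to the shared base-class handler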
+        if self.kwargs["print_method"] == "tqdm":
+            self.pbar.close()
+        super(Dynesty, self).write_current_state_and_exit(signum=signum, frame=frame)
 
     def write_current_state(self):
         """
@@ -613,29 +624,26 @@ class Dynesty(NestedSampler):
         normal running.
         """
 
-        from ... import __version__ as bilby_version
-        from dynesty import __version__ as dynesty_version
         import dill
+        from dynesty import __version__ as dynesty_version
+
+        from ... import __version__ as bilby_version
 
         if getattr(self, "sampler", None) is None:
             # Sampler not initialized, not able to write current state
             return
 
         check_directory_exists_and_if_not_mkdir(self.outdir)
-        end_time = datetime.datetime.now()
-        if hasattr(self, 'start_time'):
-            self.sampling_time += end_time - self.start_time
-            self.start_time = end_time
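+        # fold the elapsed wall time into self.sampling_time before pickling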
+        if hasattr(self, "start_time"):
+            self._update_sampling_time()
             self.sampler.kwargs["sampling_time"] = self.sampling_time
             self.sampler.kwargs["start_time"] = self.start_time
-        self.sampler.versions = dict(
-            bilby=bilby_version, dynesty=dynesty_version
-        )
+        self.sampler.versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         self.sampler.pool = None
         self.sampler.M = map
         if dill.pickles(self.sampler):
             safe_file_dump(self.sampler, self.resume_file, dill)
-            logger.info("Written checkpoint file {}".format(self.resume_file))
+            logger.info(f"Written checkpoint file {self.resume_file}")
         else:
             logger.warning(
                 "Cannot write pickle resume file! "
@@ -657,86 +665,108 @@ class Dynesty(NestedSampler):
         if nsamples < 100:
             return
 
-        filename = "{}/{}_samples.dat".format(self.outdir, self.label)
-        logger.info("Writing {} current samples to {}".format(nsamples, filename))
+        filename = f"{self.outdir}/{self.label}_samples.dat"
+        logger.info(f"Writing {nsamples} current samples to {filename}")
 
         df = DataFrame(samples, columns=self.search_parameter_keys)
-        df.to_csv(filename, index=False, header=True, sep=' ')
+        df.to_csv(filename, index=False, header=True, sep=" ")
 
     def plot_current_state(self):
         import matplotlib.pyplot as plt
+
         if self.check_point_plot:
             import dynesty.plotting as dyplot
-            labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
+
+            labels = [label.replace("_", " ") for label in self.search_parameter_keys]
             try:
-                filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_trace.png"
                 fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
                 fig.tight_layout()
                 fig.savefig(filename)
-            except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e:
+            except (
+                RuntimeError,
+                np.linalg.linalg.LinAlgError,
+                ValueError,
+                OverflowError,
+                Exception,
+            ) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty state plot at checkpoint')
+                logger.warning("Failed to create dynesty state plot at checkpoint")
             finally:
                 plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_trace_unit.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_trace_unit.png"
                 from copy import deepcopy
+
                 temp = deepcopy(self.sampler.results)
                 temp["samples"] = temp["samples_u"]
                 fig = dyplot.traceplot(temp, labels=labels)[0]
                 fig.tight_layout()
                 fig.savefig(filename)
-            except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError, OverflowError, Exception) as e:
+            except (
+                RuntimeError,
+                np.linalg.linalg.LinAlgError,
+                ValueError,
+                OverflowError,
+                Exception,
+            ) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty unit state plot at checkpoint')
+                logger.warning("Failed to create dynesty unit state plot at checkpoint")
             finally:
                 plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_run.png"
                 fig, axs = dyplot.runplot(
-                    self.sampler.results, logplot=False, use_math_text=False)
+                    self.sampler.results, logplot=False, use_math_text=False
+                )
                 fig.tight_layout()
                 plt.savefig(filename)
             except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty run plot at checkpoint')
+                logger.warning("Failed to create dynesty run plot at checkpoint")
             finally:
-                plt.close('all')
+                plt.close("all")
             try:
-                filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
+                filename = f"{self.outdir}/{self.label}_checkpoint_stats.png"
                 fig, axs = dynesty_stats_plot(self.sampler)
                 fig.tight_layout()
                 plt.savefig(filename)
             except (RuntimeError, ValueError) as e:
                 logger.warning(e)
-                logger.warning('Failed to create dynesty stats plot at checkpoint')
+                logger.warning("Failed to create dynesty stats plot at checkpoint")
             finally:
-                plt.close('all')
+                plt.close("all")
 
     def generate_trace_plots(self, dynesty_results):
         check_directory_exists_and_if_not_mkdir(self.outdir)
-        filename = '{}/{}_trace.png'.format(self.outdir, self.label)
-        logger.debug("Writing trace plot to {}".format(filename))
+        filename = f"{self.outdir}/{self.label}_trace.png"
+        logger.debug(f"Writing trace plot to {filename}")
         from dynesty import plotting as dyplot
-        fig, axes = dyplot.traceplot(dynesty_results,
-                                     labels=self.result.parameter_labels)
+
+        fig, axes = dyplot.traceplot(
+            dynesty_results, labels=self.result.parameter_labels
+        )
         fig.tight_layout()
         fig.savefig(filename)
 
     def _run_test(self):
         import dynesty
         import pandas as pd
+
         self.sampler = dynesty.NestedSampler(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.sampler_init_kwargs)
+            ndim=self.ndim,
+            **self.sampler_init_kwargs,
+        )
         sampler_kwargs = self.sampler_function_kwargs.copy()
-        sampler_kwargs['maxiter'] = 2
+        sampler_kwargs["maxiter"] = 2
 
         self.sampler.run_nested(**sampler_kwargs)
         N = 100
-        self.result.samples = pd.DataFrame(
-            self.priors.sample(N))[self.search_parameter_keys].values
+        self.result.samples = pd.DataFrame(self.priors.sample(N))[
+            self.search_parameter_keys
+        ].values
         self.result.nested_samples = self.result.samples
         self.result.log_likelihood_evaluations = np.ones(N)
         self.result.log_evidence = 1
@@ -745,7 +775,7 @@ class Dynesty(NestedSampler):
         return self.result
 
     def prior_transform(self, theta):
-        """ Prior transform method that is passed into the external sampler.
+        """Prior transform method that is passed into the external sampler.
         cube we map this back to [0, 1].
 
         Parameters
@@ -762,25 +792,24 @@ class Dynesty(NestedSampler):
 
 
 def sample_rwalk_bilby(args):
-    """ Modified bilby-implemented version of dynesty.sampling.sample_rwalk """
+    """Modified bilby-implemented version of dynesty.sampling.sample_rwalk"""
     from dynesty.utils import unitcheck
 
     # Unzipping.
-    (u, loglstar, axes, scale,
-     prior_transform, loglikelihood, kwargs) = args
+    (u, loglstar, axes, scale, prior_transform, loglikelihood, kwargs) = args
     rstate = np.random
 
     # Bounds
-    nonbounded = kwargs.get('nonbounded', None)
-    periodic = kwargs.get('periodic', None)
-    reflective = kwargs.get('reflective', None)
+    nonbounded = kwargs.get("nonbounded", None)
+    periodic = kwargs.get("periodic", None)
+    reflective = kwargs.get("reflective", None)
 
     # Setup.
     n = len(u)
-    walks = kwargs.get('walks', 100)  # minimum number of steps
-    maxmcmc = kwargs.get('maxmcmc', 5000)  # Maximum number of steps
-    nact = kwargs.get('nact', 5)  # Number of ACT
-    old_act = kwargs.get('old_act', walks)
+    walks = kwargs.get("walks", 100)  # minimum number of steps
+    maxmcmc = kwargs.get("maxmcmc", 5000)  # Maximum number of steps
+    nact = kwargs.get("nact", 5)  # Number of ACT
+    old_act = kwargs.get("old_act", walks)
 
     # Initialize internal variables
     accept = 0
@@ -848,19 +877,21 @@ def sample_rwalk_bilby(args):
         if accept + reject > walks:
             act = estimate_nmcmc(
                 accept_ratio=accept / (accept + reject + nfail),
-                old_act=old_act, maxmcmc=maxmcmc)
+                old_act=old_act,
+                maxmcmc=maxmcmc,
+            )
 
         # If we've taken too many likelihood evaluations then break
         if accept + reject > maxmcmc:
             warnings.warn(
-                "Hit maximum number of walks {} with accept={}, reject={}, "
-                "and nfail={} try increasing maxmcmc"
-                .format(maxmcmc, accept, reject, nfail))
+                f"Hit maximum number of walks {maxmcmc} with accept={accept},"
+                f" reject={reject}, and nfail={nfail} try increasing maxmcmc"
+            )
             break
 
     # If the act is finite, pick randomly from within the chain
-    if np.isfinite(act) and int(.5 * nact * act) < len(u_list):
-        idx = np.random.randint(int(.5 * nact * act), len(u_list))
+    if np.isfinite(act) and int(0.5 * nact * act) < len(u_list):
+        idx = np.random.randint(int(0.5 * nact * act), len(u_list))
         u = u_list[idx]
         v = v_list[idx]
         logl = logl_list[idx]
@@ -870,7 +901,7 @@ def sample_rwalk_bilby(args):
         v = prior_transform(u)
         logl = loglikelihood(v)
 
-    blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
+    blob = {"accept": accept, "reject": reject, "fail": nfail, "scale": scale}
     kwargs["old_act"] = act
 
     ncall = accept + reject
@@ -878,7 +909,7 @@ def sample_rwalk_bilby(args):
 
 
 def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
-    """ Estimate autocorrelation length of chain using acceptance fraction
+    """Estimate autocorrelation length of chain using acceptance fraction
 
     Using ACL = (2/acc) - 1 multiplied by a safety margin. Code adapted from CPNest:
 
@@ -905,9 +936,8 @@ def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
     if accept_ratio == 0.0:
         Nmcmc_exact = (1 + 1 / tau) * old_act
     else:
-        Nmcmc_exact = (
-            (1. - 1. / tau) * old_act +
-            (safety / tau) * (2. / accept_ratio - 1.)
+        Nmcmc_exact = (1.0 - 1.0 / tau) * old_act + (safety / tau) * (
+            2.0 / accept_ratio - 1.0
         )
         Nmcmc_exact = float(min(Nmcmc_exact, maxmcmc))
     return max(safety, int(Nmcmc_exact))
@@ -943,7 +973,7 @@ def dynesty_stats_plot(sampler):
 
     fig, axs = plt.subplots(nrows=4, figsize=(8, 8))
     for ax, name in zip(axs, ["nc", "scale"]):
-        ax.plot(getattr(sampler, "saved_{}".format(name)), color="blue")
+        ax.plot(getattr(sampler, f"saved_{name}"), color="blue")
         ax.set_ylabel(name.title())
     lifetimes = np.arange(len(sampler.saved_it)) - sampler.saved_it
     axs[-2].set_ylabel("Lifetime")
@@ -951,9 +981,17 @@ def dynesty_stats_plot(sampler):
     burn = int(geom(p=1 / nlive).isf(1 / 2 / nlive))
     if len(sampler.saved_it) > burn + sampler.nlive:
         axs[-2].plot(np.arange(0, burn), lifetimes[:burn], color="grey")
-        axs[-2].plot(np.arange(burn, len(lifetimes) - nlive), lifetimes[burn: -nlive], color="blue")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
-        lifetimes = lifetimes[burn: -nlive]
+        axs[-2].plot(
+            np.arange(burn, len(lifetimes) - nlive),
+            lifetimes[burn:-nlive],
+            color="blue",
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
+        lifetimes = lifetimes[burn:-nlive]
         ks_result = ks_1samp(lifetimes, geom(p=1 / nlive).cdf)
         axs[-1].hist(
             lifetimes,
@@ -961,19 +999,25 @@ def dynesty_stats_plot(sampler):
             histtype="step",
             density=True,
             color="blue",
-            label=f"p value = {ks_result.pvalue:.3f}"
+            label=f"p value = {ks_result.pvalue:.3f}",
         )
         axs[-1].plot(
             np.arange(1, 6 * nlive),
             geom(p=1 / nlive).pmf(np.arange(1, 6 * nlive)),
-            color="red"
+            color="red",
         )
         axs[-1].set_xlim(0, 6 * nlive)
         axs[-1].legend()
         axs[-1].set_yscale("log")
     else:
-        axs[-2].plot(np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey")
-        axs[-2].plot(np.arange(len(lifetimes) - nlive, len(lifetimes)), lifetimes[-nlive:], color="red")
+        axs[-2].plot(
+            np.arange(0, len(lifetimes) - nlive), lifetimes[:-nlive], color="grey"
+        )
+        axs[-2].plot(
+            np.arange(len(lifetimes) - nlive, len(lifetimes)),
+            lifetimes[-nlive:],
+            color="red",
+        )
     axs[-2].set_yscale("log")
     axs[-2].set_xlabel("Iteration")
     axs[-1].set_xlabel("Lifetime")
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index d976d391ed34f0871b3dc3121f6924ee395cee41..5afa169b8959faa4c8e1e2df3a3d7397819ecd2f 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,7 +1,5 @@
 import os
-import signal
 import shutil
-import sys
 from collections import namedtuple
 from distutils.version import LooseVersion
 from shutil import copyfile
@@ -9,8 +7,11 @@ from shutil import copyfile
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import MCMCSampler, SamplerError
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import MCMCSampler, SamplerError, signal_wrapper
+from .ptemcee import LikePriorEvaluator
+
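+# module-level instance so the lnpostfn handed to emcee stays picklable for pools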
+_evaluator = LikePriorEvaluator()
 
 
 class Emcee(MCMCSampler):
@@ -45,81 +46,110 @@ class Emcee(MCMCSampler):
     """
 
     default_kwargs = dict(
-        nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None,
-        live_dangerously=False, runtime_sortingfn=None, lnprob0=None,
-        rstate0=None, blobs0=None, iterations=100, thin=1, storechain=True,
-        mh_proposal=None)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
-                 burn_in_act=3, verbose=True, **kwargs):
+        nwalkers=500,
+        a=2,
+        args=[],
+        kwargs={},
+        postargs=None,
+        pool=None,
+        live_dangerously=False,
+        runtime_sortingfn=None,
+        lnprob0=None,
+        rstate0=None,
+        blobs0=None,
+        iterations=100,
+        thin=1,
+        storechain=True,
+        mh_proposal=None,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        **kwargs,
+    ):
         import emcee
-        self.emcee = emcee
 
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         super(Emcee, self).__init__(
-            likelihood=likelihood, priors=priors, outdir=outdir,
-            label=label, use_ratio=use_ratio, plot=plot,
-            skip_import_verification=skip_import_verification, **kwargs)
-        self.emcee = self._check_version()
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+        self._check_version()
         self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
-        self.verbose = verbose
-
-        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
-        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
+        self.verbose = kwargs.get("verbose", True)
 
     def _check_version(self):
         import emcee
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+
+        if LooseVersion(emcee.__version__) > LooseVersion("2.2.1"):
             self.prerelease = True
         else:
             self.prerelease = False
         return emcee
 
     def _translate_kwargs(self, kwargs):
-        if 'nwalkers' not in kwargs:
+        if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nwalkers'] = kwargs.pop(equiv)
-        if 'iterations' not in kwargs:
-            if 'nsteps' in kwargs:
-                kwargs['iterations'] = kwargs.pop('nsteps')
-        if 'threads' in kwargs:
-            if kwargs['threads'] != 1:
-                logger.warning("The 'threads' argument cannot be used for "
-                               "parallelisation. This run will proceed "
-                               "without parallelisation, but consider the use "
-                               "of an appropriate Pool object passed to the "
-                               "'pool' keyword.")
-                kwargs['threads'] = 1
+                    kwargs["nwalkers"] = kwargs.pop(equiv)
+        if "iterations" not in kwargs:
+            if "nsteps" in kwargs:
+                kwargs["iterations"] = kwargs.pop("nsteps")
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin',
-                'storechain', 'mh_proposal']
+        keys = [
+            "lnprob0",
+            "rstate0",
+            "blobs0",
+            "iterations",
+            "thin",
+            "storechain",
+            "mh_proposal",
+        ]
 
         # updated function keywords for emcee > v2.2.1
-        updatekeys = {'p0': 'initial_state',
-                      'lnprob0': 'log_prob0',
-                      'storechain': 'store'}
+        updatekeys = {
+            "p0": "initial_state",
+            "lnprob0": "log_prob0",
+            "storechain": "store",
+        }
 
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
-        function_kwargs['p0'] = self.pos0
+        function_kwargs["p0"] = self.pos0
 
         if self.prerelease:
-            if function_kwargs['mh_proposal'] is not None:
-                logger.warning("The 'mh_proposal' option is no longer used "
-                               "in emcee v{}, and will be ignored.".format(
-                                   self.emcee.__version__))
-            del function_kwargs['mh_proposal']
+            if function_kwargs["mh_proposal"] is not None:
+                logger.warning(
+                    "The 'mh_proposal' option is no longer used "
+                    "in emcee > 2.2.1, and will be ignored."
+                )
+            del function_kwargs["mh_proposal"]
 
             for key in updatekeys:
                 if updatekeys[key] not in function_kwargs:
@@ -131,37 +161,30 @@ class Emcee(MCMCSampler):
 
     @property
     def sampler_init_kwargs(self):
-        init_kwargs = {key: value
-                       for key, value in self.kwargs.items()
-                       if key not in self.sampler_function_kwargs}
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
 
-        init_kwargs['lnpostfn'] = self.lnpostfn
-        init_kwargs['dim'] = self.ndim
+        init_kwargs["lnpostfn"] = _evaluator.call_emcee
+        init_kwargs["dim"] = self.ndim
 
         # updated init keywords for emcee > v2.2.1
-        updatekeys = {'dim': 'ndim',
-                      'lnpostfn': 'log_prob_fn'}
+        updatekeys = {"dim": "ndim", "lnpostfn": "log_prob_fn"}
 
         if self.prerelease:
             for key in updatekeys:
                 if key in init_kwargs:
                     init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
 
-            oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal']
+            oldfunckeys = ["p0", "lnprob0", "storechain", "mh_proposal"]
             for key in oldfunckeys:
                 if key in init_kwargs:
                     del init_kwargs[key]
 
         return init_kwargs
 
-    def lnpostfn(self, theta):
-        log_prior = self.log_prior(theta)
-        if np.isinf(log_prior):
-            return -np.inf, [np.nan, np.nan]
-        else:
-            log_likelihood = self.log_likelihood(theta)
-            return log_likelihood + log_prior, [log_likelihood, log_prior]
-
     @property
     def nburn(self):
         if type(self.__nburn) in [float, int]:
@@ -174,52 +197,54 @@ class Emcee(MCMCSampler):
     @nburn.setter
     def nburn(self, nburn):
         if isinstance(nburn, (float, int)):
-            if nburn > self.kwargs['iterations'] - 1:
-                raise ValueError('Number of burn-in samples must be smaller '
-                                 'than the total number of iterations')
+            if nburn > self.kwargs["iterations"] - 1:
+                raise ValueError(
+                    "Number of burn-in samples must be smaller "
+                    "than the total number of iterations"
+                )
 
         self.__nburn = nburn
 
     @property
     def nwalkers(self):
-        return self.kwargs['nwalkers']
+        return self.kwargs["nwalkers"]
 
     @property
     def nsteps(self):
-        return self.kwargs['iterations']
+        return self.kwargs["iterations"]
 
     @nsteps.setter
     def nsteps(self, nsteps):
-        self.kwargs['iterations'] = nsteps
+        self.kwargs["iterations"] = nsteps
 
     @property
     def stored_chain(self):
-        """ Read the stored zero-temperature chain data in from disk """
+        """Read the stored zero-temperature chain data in from disk"""
         return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
 
     @property
     def stored_samples(self):
-        """ Returns the samples stored on disk """
+        """Returns the samples stored on disk"""
         return self.stored_chain[self.search_parameter_keys]
 
     @property
     def stored_loglike(self):
-        """ Returns the log-likelihood stored on disk """
-        return self.stored_chain['log_l']
+        """Returns the log-likelihood stored on disk"""
+        return self.stored_chain["log_l"]
 
     @property
     def stored_logprior(self):
-        """ Returns the log-prior stored on disk """
-        return self.stored_chain['log_p']
+        """Returns the log-prior stored on disk"""
+        return self.stored_chain["log_p"]
 
     def _init_chain_file(self):
         with open(self.checkpoint_info.chain_file, "w+") as ff:
-            ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
-                '\t'.join(self.search_parameter_keys)))
+            search_keys_str = "\t".join(self.search_parameter_keys)
+            ff.write(f"walker\t{search_keys_str}\tlog_l\tlog_p\n")
 
     @property
     def checkpoint_info(self):
-        """ Defines various things related to checkpointing and storing data
+        """Defines various things related to checkpointing and storing data
 
         Returns
         =======
@@ -231,21 +256,25 @@ class Emcee(MCMCSampler):
 
         """
         out_dir = os.path.join(
-            self.outdir, '{}_{}'.format(self.__class__.__name__.lower(),
-                                        self.label))
+            self.outdir, f"{self.__class__.__name__.lower()}_{self.label}"
+        )
         check_directory_exists_and_if_not_mkdir(out_dir)
 
-        chain_file = os.path.join(out_dir, 'chain.dat')
-        sampler_file = os.path.join(out_dir, 'sampler.pickle')
-        chain_template =\
-            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+        chain_file = os.path.join(out_dir, "chain.dat")
+        sampler_file = os.path.join(out_dir, "sampler.pickle")
+        chain_template = (
+            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
+        )
 
         CheckpointInfo = namedtuple(
-            'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
+            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
+        )
 
         checkpoint_info = CheckpointInfo(
-            sampler_file=sampler_file, chain_file=chain_file,
-            chain_template=chain_template)
+            sampler_file=sampler_file,
+            chain_file=chain_file,
+            chain_template=chain_template,
+        )
 
         return checkpoint_info
 
@@ -254,43 +283,48 @@ class Emcee(MCMCSampler):
         nsteps = self._previous_iterations
         return self.sampler.chain[:, :nsteps, :]
 
-    def checkpoint(self):
-        """ Writes a pickle file of the sampler to disk using dill """
+    def write_current_state(self):
+        """Writes a pickle file of the sampler to disk using dill"""
         import dill
-        logger.info("Checkpointing sampler to file {}"
-                    .format(self.checkpoint_info.sampler_file))
-        with open(self.checkpoint_info.sampler_file, 'wb') as f:
+
+        logger.info(
+            f"Checkpointing sampler to file {self.checkpoint_info.sampler_file}"
+        )
+        with open(self.checkpoint_info.sampler_file, "wb") as f:
             # Overwrites the stored sampler chain with one that is truncated
             # to only the completed steps
             self.sampler._chain = self.sampler_chain
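+            # the multiprocessing pool cannot be pickled: detach it while
+            # dumping and reattach it afterwards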
+            _pool = self.sampler.pool
+            self.sampler.pool = None
             dill.dump(self._sampler, f)
-
-    def checkpoint_and_exit(self, signum, frame):
-        logger.info("Received signal {}".format(signum))
-        self.checkpoint()
-        sys.exit()
+            self.sampler.pool = _pool
 
     def _initialise_sampler(self):
-        self._sampler = self.emcee.EnsembleSampler(**self.sampler_init_kwargs)
+        from emcee import EnsembleSampler
+
+        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
         self._init_chain_file()
 
     @property
     def sampler(self):
-        """ Returns the emcee sampler object
+        """Returns the emcee sampler object
 
         If already initialized, returns the stored _sampler value. Otherwise,
         first checks if there is a pickle file from which to load. If there is
         not, initializes the sampler and sets the initial random draw.
 
         """
-        if hasattr(self, '_sampler'):
+        if hasattr(self, "_sampler"):
             pass
         elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
             import dill
-            logger.info("Resuming run from checkpoint file {}"
-                        .format(self.checkpoint_info.sampler_file))
-            with open(self.checkpoint_info.sampler_file, 'rb') as f:
+
+            logger.info(
+                f"Resuming run from checkpoint file {self.checkpoint_info.sampler_file}"
+            )
+            with open(self.checkpoint_info.sampler_file, "rb") as f:
                 self._sampler = dill.load(f)
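+                # reattach the live pool, which was stripped before pickling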
+                self._sampler.pool = self.pool
             self._set_pos0_for_resume()
         else:
             self._initialise_sampler()
@@ -299,7 +333,7 @@ class Emcee(MCMCSampler):
 
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
-        temp_chain_file = chain_file + '.temp'
+        temp_chain_file = chain_file + ".temp"
         if os.path.isfile(chain_file):
             copyfile(chain_file, temp_chain_file)
         if self.prerelease:
@@ -313,7 +347,7 @@ class Emcee(MCMCSampler):
 
     @property
     def _previous_iterations(self):
-        """ Returns the number of iterations that the sampler has saved
+        """Returns the number of iterations that the sampler has saved
 
         This is used when loading in a sampler from a pickle file to figure out
         how much of the run has already been completed
@@ -325,7 +359,8 @@ class Emcee(MCMCSampler):
 
     def _draw_pos0_from_prior(self):
         return np.array(
-            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
+            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+        )
 
     @property
     def _pos0_shape(self):
@@ -340,8 +375,7 @@ class Emcee(MCMCSampler):
                 self.pos0 = np.squeeze(self.pos0)
 
             if self.pos0.shape != self._pos0_shape:
-                raise ValueError(
-                    'Input pos0 should be of shape ndim, nwalkers')
+                raise ValueError("Input pos0 should be of shape ndim, nwalkers")
             logger.debug("Checking input pos0")
             for draw in self.pos0:
                 self.check_draw(draw)
@@ -352,38 +386,39 @@ class Emcee(MCMCSampler):
     def _set_pos0_for_resume(self):
         self.pos0 = self.sampler.chain[:, -1, :]
 
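+    # checkpoint-and-exit signal handling is supplied by the shared signal_wrapper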
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         from tqdm.auto import tqdm
+
         sampler_function_kwargs = self.sampler_function_kwargs
-        iterations = sampler_function_kwargs.pop('iterations')
+        iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
 
         if self.prerelease:
-            sampler_function_kwargs['initial_state'] = self.pos0
+            sampler_function_kwargs["initial_state"] = self.pos0
         else:
-            sampler_function_kwargs['p0'] = self.pos0
+            sampler_function_kwargs["p0"] = self.pos0
 
         # main iteration loop
-        iterator = self.sampler.sample(
-            iterations=iterations, **sampler_function_kwargs
-        )
+        iterator = self.sampler.sample(iterations=iterations, **sampler_function_kwargs)
         if self.verbose:
             iterator = tqdm(iterator, total=iterations)
         for sample in iterator:
             self.write_chains_to_file(sample)
         if self.verbose:
             iterator.close()
-        self.checkpoint()
+        self.write_current_state()
 
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(
-            self.sampler.chain.reshape((-1, self.ndim)))
+        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
         self.print_nburn_logging_info()
 
         self._generate_result()
 
-        self.result.samples = self.sampler.chain[:, self.nburn:, :].reshape(
-            (-1, self.ndim))
+        self.result.samples = self.sampler.chain[:, self.nburn :, :].reshape(
+            (-1, self.ndim)
+        )
         self.result.walkers = self.sampler.chain
         return self.result
 
@@ -393,10 +428,11 @@ class Emcee(MCMCSampler):
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps` ({} < {}). Try increasing the "
-                "number of steps.".format(self.result.nburn, self.nsteps))
+                f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})."
+                " Try increasing the number of steps."
+            )
         blobs = np.array(self.sampler.blobs)
-        blobs_trimmed = blobs[self.nburn:, :, :].reshape((-1, 2))
+        blobs_trimmed = blobs[self.nburn :, :, :].reshape((-1, 2))
         log_likelihoods, log_priors = blobs_trimmed.T
         self.result.log_likelihood_evaluations = log_likelihoods
         self.result.log_prior_evaluations = log_priors
diff --git a/bilby/core/sampler/fake_sampler.py b/bilby/core/sampler/fake_sampler.py
index 8d218472d13f7769c4b88cdd0155cb8f7c04dd0d..5f375fdbad8055e6a5bdaf7dd7e99caabe330f33 100644
--- a/bilby/core/sampler/fake_sampler.py
+++ b/bilby/core/sampler/fake_sampler.py
@@ -1,8 +1,7 @@
-
 import numpy as np
 
-from .base_sampler import Sampler
 from ..result import read_in_result
+from .base_sampler import Sampler
 
 
 class FakeSampler(Sampler):
@@ -17,17 +16,38 @@ class FakeSampler(Sampler):
     sample_file: str
         A string pointing to the posterior data file to be loaded.
     """
-    default_kwargs = dict(verbose=True, logl_args=None, logl_kwargs=None,
-                          print_progress=True)
-
-    def __init__(self, likelihood, priors, sample_file, outdir='outdir',
-                 label='label', use_ratio=False, plot=False,
-                 injection_parameters=None, meta_data=None, result_class=None,
-                 **kwargs):
-        super(FakeSampler, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                          use_ratio=False, plot=False, skip_import_verification=True,
-                                          injection_parameters=None, meta_data=None, result_class=None,
-                                          **kwargs)
+
+    default_kwargs = dict(
+        verbose=True, logl_args=None, logl_kwargs=None, print_progress=True
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        sample_file,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        injection_parameters=None,
+        meta_data=None,
+        result_class=None,
+        **kwargs
+    ):
+        super(FakeSampler, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=False,
+            plot=False,
+            skip_import_verification=True,
+            injection_parameters=None,
+            meta_data=None,
+            result_class=None,
+            **kwargs
+        )
         self._read_parameter_list_from_file(sample_file)
         self.result.outdir = outdir
         self.result.label = label
@@ -41,7 +61,7 @@ class FakeSampler(Sampler):
 
     def run_sampler(self):
         """Compute the likelihood for the list of parameter space points."""
-        self.sampler = 'fake_sampler'
+        self.sampler = "fake_sampler"
 
         # Flushes the output to force a line break
         if self.kwargs["verbose"]:
@@ -59,8 +79,12 @@ class FakeSampler(Sampler):
             likelihood_ratios.append(logl)
 
             if self.kwargs["verbose"]:
-                print(self.likelihood.parameters['log_likelihood'], likelihood_ratios[-1],
-                      self.likelihood.parameters['log_likelihood'] - likelihood_ratios[-1])
+                print(
+                    self.likelihood.parameters["log_likelihood"],
+                    likelihood_ratios[-1],
+                    self.likelihood.parameters["log_likelihood"]
+                    - likelihood_ratios[-1],
+                )
 
         self.result.log_likelihood_evaluations = np.array(likelihood_ratios)
 
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index 83947fc88378c5401508eac458192141cd9f221e..1f09387cc33520a7c8408db7cd7af2a924fa85cf 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -2,8 +2,12 @@ import os
 
 import numpy as np
 
-from .emcee import Emcee
 from ..utils import logger
+from .base_sampler import signal_wrapper
+from .emcee import Emcee
+from .ptemcee import LikePriorEvaluator
+
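+# module-level instance so the lnpostfn handed to kombine stays picklable for pools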
+_evaluator = LikePriorEvaluator()
 
 
 class Kombine(Emcee):
@@ -35,21 +39,61 @@ class Kombine(Emcee):
 
     """
 
-    default_kwargs = dict(nwalkers=500, args=[], pool=None, transd=False,
-                          lnpost0=None, blob0=None, iterations=500, storechain=True, processes=1, update_interval=None,
-                          kde=None, kde_size=None, spaces=None, freeze_transd=False, test_steps=16, critical_pval=0.05,
-                          max_steps=None, burnin_verbose=False)
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
-                 burn_in_act=3, autoburnin=False, **kwargs):
-        super(Kombine, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                      use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification,
-                                      pos0=pos0, nburn=nburn, burn_in_fraction=burn_in_fraction,
-                                      burn_in_act=burn_in_act, resume=resume, **kwargs)
-
-        if self.kwargs['nwalkers'] > self.kwargs['iterations']:
+    default_kwargs = dict(
+        nwalkers=500,
+        args=[],
+        pool=None,
+        transd=False,
+        lnpost0=None,
+        blob0=None,
+        iterations=500,
+        storechain=True,
+        processes=1,
+        update_interval=None,
+        kde=None,
+        kde_size=None,
+        spaces=None,
+        freeze_transd=False,
+        test_steps=16,
+        critical_pval=0.05,
+        max_steps=None,
+        burnin_verbose=False,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        autoburnin=False,
+        **kwargs,
+    ):
+        super(Kombine, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            pos0=pos0,
+            nburn=nburn,
+            burn_in_fraction=burn_in_fraction,
+            burn_in_act=burn_in_act,
+            resume=resume,
+            **kwargs,
+        )
+
+        if self.kwargs["nwalkers"] > self.kwargs["iterations"]:
             raise ValueError("Kombine Sampler requires Iterations be > nWalkers")
         self.autoburnin = autoburnin
 
@@ -57,42 +101,34 @@ class Kombine(Emcee):
         # set prerelease to False to prevent checks for newer emcee versions in parent class
         self.prerelease = False
 
-    def _translate_kwargs(self, kwargs):
-        if 'nwalkers' not in kwargs:
-            for equiv in self.nwalkers_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs['nwalkers'] = kwargs.pop(equiv)
-        if 'iterations' not in kwargs:
-            if 'nsteps' in kwargs:
-                kwargs['iterations'] = kwargs.pop('nsteps')
-        # make sure processes kwarg is 1
-        if 'processes' in kwargs:
-            if kwargs['processes'] != 1:
-                logger.warning("The 'processes' argument cannot be used for "
-                               "parallelisation. This run will proceed "
-                               "without parallelisation, but consider the use "
-                               "of an appropriate Pool object passed to the "
-                               "'pool' keyword.")
-                kwargs['processes'] = 1
-
     @property
     def sampler_function_kwargs(self):
-        keys = ['lnpost0', 'blob0', 'iterations', 'storechain', 'lnprop0', 'update_interval', 'kde',
-                'kde_size', 'spaces', 'freeze_transd']
+        keys = [
+            "lnpost0",
+            "blob0",
+            "iterations",
+            "storechain",
+            "lnprop0",
+            "update_interval",
+            "kde",
+            "kde_size",
+            "spaces",
+            "freeze_transd",
+        ]
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
-        function_kwargs['p0'] = self.pos0
+        function_kwargs["p0"] = self.pos0
         return function_kwargs
 
     @property
     def sampler_burnin_kwargs(self):
-        extra_keys = ['test_steps', 'critical_pval', 'max_steps', 'burnin_verbose']
-        removal_keys = ['iterations', 'spaces', 'freeze_transd']
+        extra_keys = ["test_steps", "critical_pval", "max_steps", "burnin_verbose"]
+        removal_keys = ["iterations", "spaces", "freeze_transd"]
         burnin_kwargs = self.sampler_function_kwargs.copy()
         for key in extra_keys:
             if key in self.kwargs:
                 burnin_kwargs[key] = self.kwargs[key]
-        if 'burnin_verbose' in burnin_kwargs.keys():
-            burnin_kwargs['verbose'] = burnin_kwargs.pop('burnin_verbose')
+        if "burnin_verbose" in burnin_kwargs.keys():
+            burnin_kwargs["verbose"] = burnin_kwargs.pop("burnin_verbose")
         for key in removal_keys:
             if key in burnin_kwargs.keys():
                 burnin_kwargs.pop(key)
@@ -100,19 +136,21 @@ class Kombine(Emcee):
 
     @property
     def sampler_init_kwargs(self):
-        init_kwargs = {key: value
-                       for key, value in self.kwargs.items()
-                       if key not in self.sampler_function_kwargs and key not in self.sampler_burnin_kwargs}
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+            and key not in self.sampler_burnin_kwargs
+        }
         init_kwargs.pop("burnin_verbose")
-        init_kwargs['lnpostfn'] = self.lnpostfn
-        init_kwargs['ndim'] = self.ndim
+        init_kwargs["lnpostfn"] = _evaluator.call_emcee
+        init_kwargs["ndim"] = self.ndim
 
-        # have to make sure pool is None so sampler will be pickleable
-        init_kwargs['pool'] = None
         return init_kwargs
 
     def _initialise_sampler(self):
         import kombine
+
         self._sampler = kombine.Sampler(**self.sampler_init_kwargs)
         self._init_chain_file()
 
@@ -129,7 +167,9 @@ class Kombine(Emcee):
     def check_resume(self):
         return self.resume and os.path.isfile(self.checkpoint_info.sampler_file)
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         if self.autoburnin:
             if self.check_resume():
                 logger.info("Resuming with autoburnin=True skips burnin process:")
@@ -138,29 +178,50 @@ class Kombine(Emcee):
                 self.sampler.burnin(**self.sampler_burnin_kwargs)
                 self.kwargs["iterations"] += self._previous_iterations
                 self.nburn = self._previous_iterations
-                logger.info("Kombine auto-burnin complete. Removing {} samples from chains".format(self.nburn))
+                logger.info(
+                    f"Kombine auto-burnin complete. Removing {self.nburn} samples from chains"
+                )
                 self._set_pos0_for_resume()
 
         from tqdm.auto import tqdm
+
         sampler_function_kwargs = self.sampler_function_kwargs
-        iterations = sampler_function_kwargs.pop('iterations')
+        iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
-        sampler_function_kwargs['p0'] = self.pos0
+        sampler_function_kwargs["p0"] = self.pos0
         for sample in tqdm(
-                self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
-                total=iterations):
+            self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
+            total=iterations,
+        ):
             self.write_chains_to_file(sample)
-        self.checkpoint()
+        self.write_current_state()
         self.result.sampler_output = np.nan
         if not self.autoburnin:
             tmp_chain = self.sampler.chain.copy()
             self.calculate_autocorrelation(tmp_chain.reshape((-1, self.ndim)))
             self.print_nburn_logging_info()
+        self._close_pool()
 
         self._generate_result()
         self.result.log_evidence_err = np.nan
 
-        tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
+        tmp_chain = self.sampler.chain[self.nburn :, :, :].copy()
         self.result.samples = tmp_chain.reshape((-1, self.ndim))
-        self.result.walkers = self.sampler.chain.reshape((self.nwalkers, self.nsteps, self.ndim))
+        self.result.walkers = self.sampler.chain.reshape(
+            (self.nwalkers, self.nsteps, self.ndim)
+        )
         return self.result
+
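+    # fall back to kombine's SerialPool when the base setup created no pool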
+    def _setup_pool(self):
+        from kombine import SerialPool
+
+        super(Kombine, self)._setup_pool()
+        if self.pool is None:
+            self.pool = SerialPool()
+
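+    # a bare SerialPool owns no worker processes, so discard it before teardown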
+    def _close_pool(self):
+        from kombine import SerialPool
+
+        if isinstance(self.pool, SerialPool):
+            self.pool = None
+        super(Kombine, self)._close_pool()
diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py
index d8bb578a4d681c8aca425cd0b4825c651dfa73e0..fdee87b058647a47a5671aed9a9ab14d23ad0573 100644
--- a/bilby/core/sampler/nessai.py
+++ b/bilby/core/sampler/nessai.py
@@ -1,9 +1,10 @@
-import numpy as np
 import os
+
+import numpy as np
 from pandas import DataFrame
 
+from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger
 from .base_sampler import NestedSampler
-from ..utils import logger, check_directory_exists_and_if_not_mkdir, load_json
 
 
 class Nessai(NestedSampler):
@@ -16,8 +17,9 @@ class Nessai(NestedSampler):
 
     Documentation: https://nessai.readthedocs.io/
     """
+
     _default_kwargs = None
-    seed_equiv_kwargs = ['sampling_seed']
+    seed_equiv_kwargs = ["sampling_seed"]
 
     @property
     def default_kwargs(self):
@@ -29,6 +31,7 @@ class Nessai(NestedSampler):
         """
         if not self._default_kwargs:
             from inspect import signature
+
             from nessai.flowsampler import FlowSampler
             from nessai.nestedsampler import NestedSampler
             from nessai.proposal import AugmentedFlowProposal, FlowProposal
@@ -42,12 +45,14 @@ class Nessai(NestedSampler):
             ]
             for c in classes:
                 kwargs.update(
-                    {k: v.default for k, v in signature(c).parameters.items() if v.default is not v.empty}
+                    {
+                        k: v.default
+                        for k, v in signature(c).parameters.items()
+                        if v.default is not v.empty
+                    }
                 )
             # Defaults for bilby that will override nessai defaults
-            bilby_defaults = dict(
-                output=None,
-            )
+            bilby_defaults = dict(output=None, exit_code=self.exit_code)
             kwargs.update(bilby_defaults)
             self._default_kwargs = kwargs
         return self._default_kwargs
@@ -69,8 +74,8 @@ class Nessai(NestedSampler):
 
     def run_sampler(self):
         from nessai.flowsampler import FlowSampler
-        from nessai.model import Model as BaseModel
         from nessai.livepoint import dict_to_live_points, live_points_to_array
+        from nessai.model import Model as BaseModel
         from nessai.posterior import compute_weights
         from nessai.utils import setup_logger
 
@@ -85,6 +90,7 @@ class Nessai(NestedSampler):
                 Priors to use for sampling. Needed for the bounds and the
                 `sample` method.
             """
+
             def __init__(self, names, priors):
                 self.names = names
                 self.priors = priors
@@ -103,8 +109,10 @@ class Nessai(NestedSampler):
                 return self.log_prior(theta)
 
             def _update_bounds(self):
-                self.bounds = {key: [self.priors[key].minimum, self.priors[key].maximum]
-                               for key in self.names}
+                self.bounds = {
+                    key: [self.priors[key].minimum, self.priors[key].maximum]
+                    for key in self.names
+                }
 
             def new_point(self, N=1):
                 """Draw a point from the prior"""
@@ -117,20 +125,22 @@ class Nessai(NestedSampler):
                 return self.log_prior(x)
 
         # Setup the logger for nessai using the same settings as the bilby logger
-        setup_logger(self.outdir, label=self.label,
-                     log_level=logger.getEffectiveLevel())
+        setup_logger(
+            self.outdir, label=self.label, log_level=logger.getEffectiveLevel()
+        )
         model = Model(self.search_parameter_keys, self.priors)
-        out = None
-        while out is None:
-            try:
-                out = FlowSampler(model, **self.kwargs)
-            except TypeError as e:
-                raise TypeError("Unable to initialise nessai sampler with error: {}".format(e))
         try:
+            out = FlowSampler(model, **self.kwargs)
             out.run(save=True, plot=self.plot)
-        except SystemExit as e:
+        except TypeError as e:
+            raise TypeError(
+                f"Unable to initialise nessai sampler with error: {e}"
+            ) from e
+        except (SystemExit, KeyboardInterrupt) as e:
             import sys
-            logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code))
+
+            logger.info(
+                f"Caught {type(e).__name__} with args {e.args}, "
+                f"exiting with signal {self.exit_code}"
+            )
             sys.exit(self.exit_code)
 
         # Manually set likelihood evaluations because parallelisation breaks the counter
@@ -139,53 +149,61 @@ class Nessai(NestedSampler):
         self.result.samples = live_points_to_array(
             out.posterior_samples, self.search_parameter_keys
         )
-        self.result.log_likelihood_evaluations = out.posterior_samples['logL']
+        self.result.log_likelihood_evaluations = out.posterior_samples["logL"]
         self.result.nested_samples = DataFrame(out.nested_samples)
         self.result.nested_samples.rename(
-            columns=dict(logL='log_likelihood', logP='log_prior'), inplace=True)
-        _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood),
-                                         np.array(out.ns.state.nlive))
-        self.result.nested_samples['weights'] = np.exp(log_weights)
+            columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True
+        )
+        _, log_weights = compute_weights(
+            np.array(self.result.nested_samples.log_likelihood),
+            np.array(out.ns.state.nlive),
+        )
+        self.result.nested_samples["weights"] = np.exp(log_weights)
         self.result.log_evidence = out.ns.log_evidence
         self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive)
 
         return self.result
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
-        if 'n_pool' not in kwargs:
+                    kwargs["nlive"] = kwargs.pop(equiv)
+        if "n_pool" not in kwargs:
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['n_pool'] = kwargs.pop(equiv)
-        if 'seed' not in kwargs:
+                    kwargs["n_pool"] = kwargs.pop(equiv)
+            if "n_pool" not in kwargs:
+                kwargs["n_pool"] = self._npool
+        if "seed" not in kwargs:
             for equiv in self.seed_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['seed'] = kwargs.pop(equiv)
+                    kwargs["seed"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
         """
         Set the directory where the output will be written
         and check resume and checkpoint status.
         """
-        if 'config_file' in self.kwargs:
-            d = load_json(self.kwargs['config_file'], None)
+        if "config_file" in self.kwargs:
+            d = load_json(self.kwargs["config_file"], None)
             self.kwargs.update(d)
-            self.kwargs.pop('config_file')
+            self.kwargs.pop("config_file")
 
-        if not self.kwargs['plot']:
-            self.kwargs['plot'] = self.plot
+        if not self.kwargs["plot"]:
+            self.kwargs["plot"] = self.plot
 
-        if self.kwargs['n_pool'] == 1 and self.kwargs['max_threads'] == 1:
-            logger.warning('Setting pool to None (n_pool=1 & max_threads=1)')
-            self.kwargs['n_pool'] = None
+        if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1:
+            logger.warning("Setting pool to None (n_pool=1 & max_threads=1)")
+            self.kwargs["n_pool"] = None
 
-        if not self.kwargs['output']:
-            self.kwargs['output'] = os.path.join(
-                self.outdir, '{}_nessai'.format(self.label), ''
+        if not self.kwargs["output"]:
+            self.kwargs["output"] = os.path.join(
+                self.outdir, f"{self.label}_nessai", ""
             )
 
-        check_directory_exists_and_if_not_mkdir(self.kwargs['output'])
+        check_directory_exists_and_if_not_mkdir(self.kwargs["output"])
         NestedSampler._verify_kwargs_against_default_kwargs(self)
+
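+    # nessai builds its own pool from n_pool, so bilby's pool setup is skipped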
+    def _setup_pool(self):
+        pass
diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py
index f598d8b1751b217ae9515019c5963001eb0da840..2ea8787a63a85b62101e6f0d193bae85765f3884 100644
--- a/bilby/core/sampler/nestle.py
+++ b/bilby/core/sampler/nestle.py
@@ -1,8 +1,7 @@
-
 import numpy as np
 from pandas import DataFrame
 
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, signal_wrapper
 
 
 class Nestle(NestedSampler):
@@ -25,30 +24,44 @@ class Nestle(NestedSampler):
         sampling
 
     """
-    default_kwargs = dict(verbose=True, method='multi', npoints=500,
-                          update_interval=None, npdim=None, maxiter=None,
-                          maxcall=None, dlogz=None, decline_factor=None,
-                          rstate=None, callback=None, steps=20, enlarge=1.2)
+
+    default_kwargs = dict(
+        verbose=True,
+        method="multi",
+        npoints=500,
+        update_interval=None,
+        npdim=None,
+        maxiter=None,
+        maxcall=None,
+        dlogz=None,
+        decline_factor=None,
+        rstate=None,
+        callback=None,
+        steps=20,
+        enlarge=1.2,
+    )
 
     def _translate_kwargs(self, kwargs):
-        if 'npoints' not in kwargs:
+        if "npoints" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['npoints'] = kwargs.pop(equiv)
-        if 'steps' not in kwargs:
+                    kwargs["npoints"] = kwargs.pop(equiv)
+        if "steps" not in kwargs:
             for equiv in self.walks_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['steps'] = kwargs.pop(equiv)
+                    kwargs["steps"] = kwargs.pop(equiv)
 
     def _verify_kwargs_against_default_kwargs(self):
-        if self.kwargs['verbose']:
+        if self.kwargs["verbose"]:
             import nestle
-            self.kwargs['callback'] = nestle.print_progress
-            self.kwargs.pop('verbose')
+
+            self.kwargs["callback"] = nestle.print_progress
+            self.kwargs.pop("verbose")
         NestedSampler._verify_kwargs_against_default_kwargs(self)
 
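+    # signal_wrapper installs the shared interrupt handlers around the run,
+    # replacing the per-sampler signal handling removed in this change.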
+    @signal_wrapper
     def run_sampler(self):
-        """ Runs Nestle sampler with given kwargs and returns the result
+        """Runs Nestle sampler with given kwargs and returns the result
 
         Returns
         =======
@@ -56,21 +69,27 @@ class Nestle(NestedSampler):
 
         """
         import nestle
+
         out = nestle.sample(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **self.kwargs)
+            ndim=self.ndim,
+            **self.kwargs
+        )
         print("")
 
         self.result.sampler_output = out
         self.result.samples = nestle.resample_equal(out.samples, out.weights)
         self.result.nested_samples = DataFrame(
-            out.samples, columns=self.search_parameter_keys)
-        self.result.nested_samples['weights'] = out.weights
-        self.result.nested_samples['log_likelihood'] = out.logl
+            out.samples, columns=self.search_parameter_keys
+        )
+        self.result.nested_samples["weights"] = out.weights
+        self.result.nested_samples["log_likelihood"] = out.logl
         self.result.log_likelihood_evaluations = self.reorder_loglikelihoods(
-            unsorted_loglikelihoods=out.logl, unsorted_samples=out.samples,
-            sorted_samples=self.result.samples)
+            unsorted_loglikelihoods=out.logl,
+            unsorted_samples=out.samples,
+            sorted_samples=self.result.samples,
+        )
         self.result.log_evidence = out.logz
         self.result.log_evidence_err = out.logzerr
         self.result.information_gain = out.h
@@ -88,14 +107,24 @@ class Nestle(NestedSampler):
 
         """
         import nestle
+
         kwargs = self.kwargs.copy()
-        kwargs['maxiter'] = 2
+        kwargs["maxiter"] = 2
         nestle.sample(
             loglikelihood=self.log_likelihood,
             prior_transform=self.prior_transform,
-            ndim=self.ndim, **kwargs)
+            ndim=self.ndim,
+            **kwargs
+        )
         self.result.samples = np.random.uniform(0, 1, (100, self.ndim))
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
         self.calc_likelihood_count()
         return self.result
+
+    def write_current_state(self):
+        """
+        Nestle doesn't support checkpointing, so no current state will be
+        written on interrupt.
+        """
+        pass
diff --git a/bilby/core/sampler/polychord.py b/bilby/core/sampler/polychord.py
index 943a5c413abe7e45ff54eb4dde2c9aa8d35b7d91..617d6c7d17b22569f1b9bc23e37e58e292033625 100644
--- a/bilby/core/sampler/polychord.py
+++ b/bilby/core/sampler/polychord.py
@@ -1,7 +1,6 @@
-
 import numpy as np
 
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, signal_wrapper
 
 
 class PyPolyChord(NestedSampler):
@@ -21,32 +20,66 @@ class PyPolyChord(NestedSampler):
     To see what the keyword arguments are for, see the docstring of PyPolyChordSettings
     """
 
-    default_kwargs = dict(use_polychord_defaults=False, nlive=None, num_repeats=None,
-                          nprior=-1, do_clustering=True, feedback=1, precision_criterion=0.001,
-                          logzero=-1e30, max_ndead=-1, boost_posterior=0.0, posteriors=True,
-                          equals=True, cluster_posteriors=True, write_resume=True,
-                          write_paramnames=False, read_resume=True, write_stats=True,
-                          write_live=True, write_dead=True, write_prior=True,
-                          compression_factor=np.exp(-1), base_dir='outdir',
-                          file_root='polychord', seed=-1, grade_dims=None, grade_frac=None, nlives={})
-
+    default_kwargs = dict(
+        use_polychord_defaults=False,
+        nlive=None,
+        num_repeats=None,
+        nprior=-1,
+        do_clustering=True,
+        feedback=1,
+        precision_criterion=0.001,
+        logzero=-1e30,
+        max_ndead=-1,
+        boost_posterior=0.0,
+        posteriors=True,
+        equals=True,
+        cluster_posteriors=True,
+        write_resume=True,
+        write_paramnames=False,
+        read_resume=True,
+        write_stats=True,
+        write_live=True,
+        write_dead=True,
+        write_prior=True,
+        compression_factor=np.exp(-1),
+        base_dir="outdir",
+        file_root="polychord",
+        seed=-1,
+        grade_dims=None,
+        grade_frac=None,
+        nlives={},
+    )
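+    # hard_exit makes the interrupt handler terminate the process immediately
+    # instead of raising SystemExit (assumed to be required by PolyChord's
+    # compiled backend).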
+    hard_exit = True
+
+    @signal_wrapper
     def run_sampler(self):
         import pypolychord
         from pypolychord.settings import PolyChordSettings
-        if self.kwargs['use_polychord_defaults']:
-            settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim,
-                                         base_dir=self._sample_file_directory,
-                                         file_root=self.label)
+
+        if self.kwargs["use_polychord_defaults"]:
+            settings = PolyChordSettings(
+                nDims=self.ndim,
+                nDerived=self.ndim,
+                base_dir=self._sample_file_directory,
+                file_root=self.label,
+            )
         else:
             self._setup_dynamic_defaults()
             pc_kwargs = self.kwargs.copy()
-            pc_kwargs['base_dir'] = self._sample_file_directory
-            pc_kwargs['file_root'] = self.label
-            pc_kwargs.pop('use_polychord_defaults')
-            settings = PolyChordSettings(nDims=self.ndim, nDerived=self.ndim, **pc_kwargs)
+            pc_kwargs["base_dir"] = self._sample_file_directory
+            pc_kwargs["file_root"] = self.label
+            pc_kwargs.pop("use_polychord_defaults")
+            settings = PolyChordSettings(
+                nDims=self.ndim, nDerived=self.ndim, **pc_kwargs
+            )
         self._verify_kwargs_against_default_kwargs()
-        out = pypolychord.run_polychord(loglikelihood=self.log_likelihood, nDims=self.ndim,
-                                        nDerived=self.ndim, settings=settings, prior=self.prior_transform)
+        out = pypolychord.run_polychord(
+            loglikelihood=self.log_likelihood,
+            nDims=self.ndim,
+            nDerived=self.ndim,
+            settings=settings,
+            prior=self.prior_transform,
+        )
         self.result.log_evidence = out.logZ
         self.result.log_evidence_err = out.logZerr
         log_likelihoods, physical_parameters = self._read_sample_file()
@@ -56,24 +89,24 @@ class PyPolyChord(NestedSampler):
         return self.result
 
     def _setup_dynamic_defaults(self):
-        """ Sets up some interdependent default argument if none are given by the user """
-        if not self.kwargs['grade_dims']:
-            self.kwargs['grade_dims'] = [self.ndim]
-        if not self.kwargs['grade_frac']:
-            self.kwargs['grade_frac'] = [1.0] * len(self.kwargs['grade_dims'])
-        if not self.kwargs['nlive']:
-            self.kwargs['nlive'] = self.ndim * 25
-        if not self.kwargs['num_repeats']:
-            self.kwargs['num_repeats'] = self.ndim * 5
+        """Sets up some interdependent default argument if none are given by the user"""
+        if not self.kwargs["grade_dims"]:
+            self.kwargs["grade_dims"] = [self.ndim]
+        if not self.kwargs["grade_frac"]:
+            self.kwargs["grade_frac"] = [1.0] * len(self.kwargs["grade_dims"])
+        if not self.kwargs["nlive"]:
+            self.kwargs["nlive"] = self.ndim * 25
+        if not self.kwargs["num_repeats"]:
+            self.kwargs["num_repeats"] = self.ndim * 5
 
     def _translate_kwargs(self, kwargs):
-        if 'nlive' not in kwargs:
+        if "nlive" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['nlive'] = kwargs.pop(equiv)
+                    kwargs["nlive"] = kwargs.pop(equiv)
 
     def log_likelihood(self, theta):
-        """ Overrides the log_likelihood so that PolyChord understands it """
+        """Overrides the log_likelihood so that PolyChord understands it"""
         return super(PyPolyChord, self).log_likelihood(theta), theta
 
     def _read_sample_file(self):
@@ -87,12 +120,14 @@ class PyPolyChord(NestedSampler):
         array_like, array_like: The log_likelihoods and the associated parameters
 
         """
-        sample_file = self._sample_file_directory + '/' + self.label + '_equal_weights.txt'
+        sample_file = (
+            self._sample_file_directory + "/" + self.label + "_equal_weights.txt"
+        )
         samples = np.loadtxt(sample_file)
         log_likelihoods = -0.5 * samples[:, 1]
-        physical_parameters = samples[:, -self.ndim:]
+        physical_parameters = samples[:, -self.ndim :]
         return log_likelihoods, physical_parameters
 
     @property
     def _sample_file_directory(self):
-        return self.outdir + '/chains'
+        return self.outdir + "/chains"
diff --git a/bilby/core/sampler/proposal.py b/bilby/core/sampler/proposal.py
index 2d52616588328c8b5bdb02247d20ff5d48b71e8b..023caac5744de968d338cea30125574248b91f95 100644
--- a/bilby/core/sampler/proposal.py
+++ b/bilby/core/sampler/proposal.py
@@ -1,13 +1,12 @@
+import random
 from inspect import isclass
 
 import numpy as np
-import random
 
 from ..prior import Uniform
 
 
 class Sample(dict):
-
     def __init__(self, dictionary=None):
         if dictionary is None:
             dictionary = dict()
@@ -31,15 +30,14 @@ class Sample(dict):
 
     @classmethod
     def from_external_type(cls, external_sample, sampler_name):
-        if sampler_name == 'cpnest':
+        if sampler_name == "cpnest":
             return cls.from_cpnest_live_point(external_sample)
         return external_sample
 
 
 class JumpProposal(object):
-
     def __init__(self, priors=None):
-        """ A generic class for jump proposals
+        """A generic class for jump proposals
 
         Parameters
         ==========
@@ -56,7 +54,7 @@ class JumpProposal(object):
         self.log_j = 0.0
 
     def __call__(self, sample, **kwargs):
-        """ A generic wrapper for the jump proposal function
+        """A generic wrapper for the jump proposal function
 
         Parameters
         ==========
@@ -71,26 +69,35 @@ class JumpProposal(object):
         return self._apply_boundaries(sample)
 
     def _move_reflecting_keys(self, sample):
-        keys = [key for key in sample.keys()
-                if self.priors[key].boundary == 'reflective']
+        keys = [
+            key for key in sample.keys() if self.priors[key].boundary == "reflective"
+        ]
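+        # Reflect out-of-bounds values: map the excursion into [0, 2r) with a
+        # modulus, then fold the upper half back into the prior range.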
         for key in keys:
-            if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum:
+            if (
+                sample[key] > self.priors[key].maximum
+                or sample[key] < self.priors[key].minimum
+            ):
                 r = self.priors[key].maximum - self.priors[key].minimum
                 delta = (sample[key] - self.priors[key].minimum) % (2 * r)
                 if delta > r:
-                    sample[key] = 2 * self.priors[key].maximum - self.priors[key].minimum - delta
+                    sample[key] = (
+                        2 * self.priors[key].maximum - self.priors[key].minimum - delta
+                    )
                 elif delta < r:
                     sample[key] = self.priors[key].minimum + delta
         return sample
 
     def _move_periodic_keys(self, sample):
-        keys = [key for key in sample.keys()
-                if self.priors[key].boundary == 'periodic']
+        keys = [key for key in sample.keys() if self.priors[key].boundary == "periodic"]
         for key in keys:
-            if sample[key] > self.priors[key].maximum or sample[key] < self.priors[key].minimum:
-                sample[key] = (self.priors[key].minimum +
-                               ((sample[key] - self.priors[key].minimum) %
-                                (self.priors[key].maximum - self.priors[key].minimum)))
+            if (
+                sample[key] > self.priors[key].maximum
+                or sample[key] < self.priors[key].minimum
+            ):
+                sample[key] = self.priors[key].minimum + (
+                    (sample[key] - self.priors[key].minimum)
+                    % (self.priors[key].maximum - self.priors[key].minimum)
+                )
         return sample
 
     def _apply_boundaries(self, sample):
@@ -100,9 +107,8 @@ class JumpProposal(object):
 
 
 class JumpProposalCycle(object):
-
     def __init__(self, proposal_functions, weights, cycle_length=100):
-        """ A generic wrapper class for proposal cycles
+        """A generic wrapper class for proposal cycles
 
         Parameters
         ==========
@@ -129,8 +135,12 @@ class JumpProposalCycle(object):
         return len(self.proposal_functions)
 
     def update_cycle(self):
-        self._cycle = np.random.choice(self.proposal_functions, size=self.cycle_length,
-                                       p=self.weights, replace=True)
+        self._cycle = np.random.choice(
+            self.proposal_functions,
+            size=self.cycle_length,
+            p=self.weights,
+            replace=True,
+        )
 
     @property
     def proposal_functions(self):
@@ -190,9 +200,13 @@ class NormJump(JumpProposal):
 
 
 class EnsembleWalk(JumpProposal):
-
-    def __init__(self, random_number_generator=random.random, n_points=3, priors=None,
-                 **random_number_generator_args):
+    def __init__(
+        self,
+        random_number_generator=random.random,
+        n_points=3,
+        priors=None,
+        **random_number_generator_args
+    ):
         """
         An ensemble walk
 
@@ -213,12 +227,16 @@ class EnsembleWalk(JumpProposal):
         self.random_number_generator_args = random_number_generator_args
 
     def __call__(self, sample, **kwargs):
-        subset = random.sample(kwargs['coordinates'], self.n_points)
+        subset = random.sample(kwargs["coordinates"], self.n_points)
         for i in range(len(subset)):
-            subset[i] = Sample.from_external_type(subset[i], kwargs.get('sampler_name', None))
+            subset[i] = Sample.from_external_type(
+                subset[i], kwargs.get("sampler_name", None)
+            )
         center_of_mass = self.get_center_of_mass(subset)
         for x in subset:
-            sample += (x - center_of_mass) * self.random_number_generator(**self.random_number_generator_args)
+            sample += (x - center_of_mass) * self.random_number_generator(
+                **self.random_number_generator_args
+            )
         return super(EnsembleWalk, self).__call__(sample)
 
     @staticmethod
@@ -227,7 +245,6 @@ class EnsembleWalk(JumpProposal):
 
 
 class EnsembleStretch(JumpProposal):
-
     def __init__(self, scale=2.0, priors=None):
         """
         Stretch move. Calculates the log Jacobian which can be used in cpnest to bias future moves.
@@ -241,8 +258,10 @@ class EnsembleStretch(JumpProposal):
         self.scale = scale
 
     def __call__(self, sample, **kwargs):
-        second_sample = random.choice(kwargs['coordinates'])
-        second_sample = Sample.from_external_type(second_sample, kwargs.get('sampler_name', None))
+        second_sample = random.choice(kwargs["coordinates"])
+        second_sample = Sample.from_external_type(
+            second_sample, kwargs.get("sampler_name", None)
+        )
         step = random.uniform(-1, 1) * np.log(self.scale)
         sample = second_sample + (sample - second_sample) * np.exp(step)
         self.log_j = len(sample) * step
@@ -250,7 +269,6 @@ class EnsembleStretch(JumpProposal):
 
 
 class DifferentialEvolution(JumpProposal):
-
     def __init__(self, sigma=1e-4, mu=1.0, priors=None):
         """
         Differential evolution step. Takes two elements from the existing coordinates and differentially evolves the
@@ -268,13 +286,12 @@ class DifferentialEvolution(JumpProposal):
         self.mu = mu
 
     def __call__(self, sample, **kwargs):
-        a, b = random.sample(kwargs['coordinates'], 2)
+        a, b = random.sample(kwargs["coordinates"], 2)
         sample = sample + (b - a) * random.gauss(self.mu, self.sigma)
         return super(DifferentialEvolution, self).__call__(sample)
 
 
 class EnsembleEigenVector(JumpProposal):
-
     def __init__(self, priors=None):
         """
         Ensemble step based on the ensemble eigenvectors.
@@ -316,7 +333,7 @@ class EnsembleEigenVector(JumpProposal):
         self.eigen_values, self.eigen_vectors = np.linalg.eigh(self.covariance)
 
     def __call__(self, sample, **kwargs):
-        self.update_eigenvectors(kwargs['coordinates'])
+        self.update_eigenvectors(kwargs["coordinates"])
         i = random.randrange(len(sample))
         jump_size = np.sqrt(np.fabs(self.eigen_values[i])) * random.gauss(0, 1)
         for j, key in enumerate(sample.keys()):
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 6191af0ec51ccbb984ef034d317f26b544cb8e9b..063e2af7e472812e0a9a8f88cc2ea7c1fe384cc0 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -2,17 +2,19 @@ import copy
 import datetime
 import logging
 import os
-import signal
-import sys
 import time
 from collections import namedtuple
 
 import numpy as np
 import pandas as pd
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import SamplerError, MCMCSampler
-
+from ..utils import check_directory_exists_and_if_not_mkdir, logger
+from .base_sampler import (
+    MCMCSampler,
+    SamplerError,
+    _sampling_convenience_dump,
+    signal_wrapper,
+)
 
 ConvergenceInputs = namedtuple(
     "ConvergenceInputs",
@@ -81,7 +83,7 @@ class Ptemcee(MCMCSampler):
         the Gelman-Rubin statistic).
     min_tau: int, (1)
         A minimum tau (autocorrelation time) to accept.
-    check_point_deltaT: float, (600)
+    check_point_delta_t: float, (600)
         The period with which to checkpoint (in seconds).
     threads: int, (1)
         If threads > 1, a MultiPool object is setup and used.
@@ -163,7 +165,7 @@ class Ptemcee(MCMCSampler):
         gradient_mean_log_posterior=0.1,
         Q_tol=1.02,
         min_tau=1,
-        check_point_deltaT=600,
+        check_point_delta_t=600,
         threads=1,
         exit_code=77,
         plot=False,
@@ -173,7 +175,7 @@ class Ptemcee(MCMCSampler):
         niterations_per_check=5,
         log10beta_min=None,
         verbose=True,
-        **kwargs
+        **kwargs,
     ):
         super(Ptemcee, self).__init__(
             likelihood=likelihood,
@@ -184,25 +186,18 @@ class Ptemcee(MCMCSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
-            **kwargs
+            **kwargs,
         )
 
         self.nwalkers = self.sampler_init_kwargs["nwalkers"]
         self.ntemps = self.sampler_init_kwargs["ntemps"]
         self.max_steps = 500
 
-        # Setup up signal handling
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
         # Checkpointing inputs
         self.resume = resume
-        self.check_point_deltaT = check_point_deltaT
+        self.check_point_delta_t = check_point_delta_t
         self.check_point_plot = check_point_plot
-        self.resume_file = "{}/{}_checkpoint_resume.pickle".format(
-            self.outdir, self.label
-        )
+        self.resume_file = f"{self.outdir}/{self.label}_checkpoint_resume.pickle"
 
         # Store convergence checking inputs in a named tuple
         convergence_inputs_dict = dict(
@@ -223,7 +218,7 @@ class Ptemcee(MCMCSampler):
             niterations_per_check=niterations_per_check,
         )
         self.convergence_inputs = ConvergenceInputs(**convergence_inputs_dict)
-        logger.info("Using convergence inputs: {}".format(self.convergence_inputs))
+        logger.info(f"Using convergence inputs: {self.convergence_inputs}")
 
         # Check if threads was given as an equivalent arg
         if threads == 1:
@@ -239,32 +234,50 @@ class Ptemcee(MCMCSampler):
         self.pos0 = pos0
 
         self._periodic = [
-            self.priors[key].boundary == "periodic" for key in self.search_parameter_keys
+            self.priors[key].boundary == "periodic"
+            for key in self.search_parameter_keys
         ]
         self.priors.sample()
-        self._minima = np.array([
-            self.priors[key].minimum for key in self.search_parameter_keys
-        ])
-        self._range = np.array([
-            self.priors[key].maximum for key in self.search_parameter_keys
-        ]) - self._minima
+        self._minima = np.array(
+            [self.priors[key].minimum for key in self.search_parameter_keys]
+        )
+        self._range = (
+            np.array([self.priors[key].maximum for key in self.search_parameter_keys])
+            - self._minima
+        )
 
         self.log10beta_min = log10beta_min
         if self.log10beta_min is not None:
             betas = np.logspace(0, self.log10beta_min, self.ntemps)
-            logger.warning("Using betas {}".format(betas))
+            logger.warning(f"Using betas {betas}")
             self.kwargs["betas"] = betas
         self.verbose = verbose
 
+        self.iteration = 0
+        self.chain_array = self.get_zero_chain_array()
+        self.log_likelihood_array = self.get_zero_array()
+        self.log_posterior_array = self.get_zero_array()
+        self.beta_list = list()
+        self.tau_list = list()
+        self.tau_list_n = list()
+        self.Q_list = list()
+        self.time_per_check = list()
+
+        self.nburn = np.nan
+        self.thin = np.nan
+        self.tau_int = np.nan
+        self.nsamples_effective = 0
+        self.discard = 0
+
     @property
     def sampler_function_kwargs(self):
-        """ Kwargs passed to samper.sampler() """
+        """Kwargs passed to samper.sampler()"""
         keys = ["adapt", "swap_ratios"]
         return {key: self.kwargs[key] for key in keys}
 
     @property
     def sampler_init_kwargs(self):
-        """ Kwargs passed to initialize ptemcee.Sampler() """
+        """Kwargs passed to initialize ptemcee.Sampler()"""
         return {
             key: value
             for key, value in self.kwargs.items()
@@ -272,14 +285,14 @@ class Ptemcee(MCMCSampler):
         }
 
     def _translate_kwargs(self, kwargs):
-        """ Translate kwargs """
+        """Translate kwargs"""
         if "nwalkers" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["nwalkers"] = kwargs.pop(equiv)
 
     def get_pos0_from_prior(self):
-        """ Draw the initial positions from the prior
+        """Draw the initial positions from the prior
 
         Returns
         =======
@@ -288,16 +301,15 @@ class Ptemcee(MCMCSampler):
 
         """
         logger.info("Generating pos0 samples")
-        return np.array([
+        return np.array(
             [
-                self.get_random_draw_from_prior()
-                for _ in range(self.nwalkers)
+                [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+                for _ in range(self.kwargs["ntemps"])
             ]
-            for _ in range(self.kwargs["ntemps"])
-        ])
+        )
 
     def get_pos0_from_minimize(self, minimize_list=None):
-        """ Draw the initial positions using an initial minimization step
+        """Draw the initial positions using an initial minimization step
 
         See pos0 in the class initialization for details.
 
@@ -318,12 +330,12 @@ class Ptemcee(MCMCSampler):
         else:
             pos0 = np.array(self.get_pos0_from_prior())
 
-        logger.info("Attempting to set pos0 for {} from minimize".format(minimize_list))
+        logger.info(f"Attempting to set pos0 for {minimize_list} from minimize")
 
         likelihood_copy = copy.copy(self.likelihood)
 
         def neg_log_like(params):
-            """ Internal function to minimize """
+            """Internal function to minimize"""
             likelihood_copy.parameters.update(
                 {key: val for key, val in zip(minimize_list, params)}
             )
@@ -360,9 +372,7 @@ class Ptemcee(MCMCSampler):
         for i, key in enumerate(minimize_list):
             pos0_min = np.min(success[:, i])
             pos0_max = np.max(success[:, i])
-            logger.info(
-                "Initialize {} walkers from {}->{}".format(key, pos0_min, pos0_max)
-            )
+            logger.info(f"Initialize {key} walkers from {pos0_min}->{pos0_max}")
             j = self.search_parameter_keys.index(key)
             pos0[:, :, j] = np.random.uniform(
                 pos0_min,
@@ -375,9 +385,8 @@ class Ptemcee(MCMCSampler):
         if self.pos0.shape != (self.ntemps, self.nwalkers, self.ndim):
             raise ValueError(
                 "Shape of starting array should be (ntemps, nwalkers, ndim). "
-                "In this case that is ({}, {}, {}), got {}".format(
-                    self.ntemps, self.nwalkers, self.ndim, self.pos0.shape
-                )
+                f"In this case that is ({self.ntemps}, {self.nwalkers}, "
+                f"{self.ndim}), got {self.pos0.shape}"
             )
         else:
             return self.pos0
@@ -395,12 +404,13 @@ class Ptemcee(MCMCSampler):
         return self.get_pos0_from_array()
 
     def setup_sampler(self):
-        """ Either initialize the sampler or read in the resume file """
+        """Either initialize the sampler or read in the resume file"""
         import ptemcee
 
         if os.path.isfile(self.resume_file) and self.resume is True:
             import dill
-            logger.info("Resume data {} found".format(self.resume_file))
+
+            logger.info(f"Resume data {self.resume_file} found")
             with open(self.resume_file, "rb") as file:
                 data = dill.load(file)
 
@@ -422,9 +432,7 @@ class Ptemcee(MCMCSampler):
             self.sampler.pool = self.pool
             self.sampler.threads = self.threads
 
-            logger.info(
-                "Resuming from previous run with time={}".format(self.iteration)
-            )
+            logger.info(f"Resuming from previous run with time={self.iteration}")
 
         else:
             # Initialize the PTSampler
@@ -433,32 +441,29 @@ class Ptemcee(MCMCSampler):
                     dim=self.ndim,
                     logl=self.log_likelihood,
                     logp=self.log_prior,
-                    **self.sampler_init_kwargs
+                    **self.sampler_init_kwargs,
                 )
             else:
                 self.sampler = ptemcee.Sampler(
                     dim=self.ndim,
                     logl=do_nothing_function,
                     logp=do_nothing_function,
-                    pool=self.pool,
                     threads=self.threads,
-                    **self.sampler_init_kwargs
+                    **self.sampler_init_kwargs,
                 )
 
-                self.sampler._likeprior = LikePriorEvaluator(
-                    self.search_parameter_keys, use_ratio=self.use_ratio
-                )
+            self.sampler._likeprior = LikePriorEvaluator()
 
             # Initialize storing results
             self.iteration = 0
             self.chain_array = self.get_zero_chain_array()
             self.log_likelihood_array = self.get_zero_array()
             self.log_posterior_array = self.get_zero_array()
-            self.beta_list = []
-            self.tau_list = []
-            self.tau_list_n = []
-            self.Q_list = []
-            self.time_per_check = []
+            self.beta_list = list()
+            self.tau_list = list()
+            self.tau_list_n = list()
+            self.Q_list = list()
+            self.time_per_check = list()
             self.pos0 = self.get_pos0()
 
         return self.sampler
@@ -470,7 +475,7 @@ class Ptemcee(MCMCSampler):
         return np.zeros((self.ntemps, self.nwalkers, self.max_steps))
 
     def get_pos0(self):
-        """ Master logic for setting pos0 """
+        """Master logic for setting pos0"""
         if isinstance(self.pos0, str) and self.pos0.lower() == "prior":
             return self.get_pos0_from_prior()
         elif isinstance(self.pos0, str) and self.pos0.lower() == "minimize":
@@ -482,52 +487,55 @@ class Ptemcee(MCMCSampler):
         elif isinstance(self.pos0, dict):
             return self.get_pos0_from_dict()
         else:
-            raise SamplerError("pos0={} not implemented".format(self.pos0))
+            raise SamplerError(f"pos0={self.pos0} not implemented")
 
-    def setup_pool(self):
-        """ If threads > 1, setup a MultiPool, else run in serial mode """
-        if self.threads > 1:
-            import schwimmbad
-
-            logger.info("Creating MultiPool with {} processes".format(self.threads))
-            self.pool = schwimmbad.MultiPool(
-                self.threads, initializer=init, initargs=(self.likelihood, self.priors)
-            )
-        else:
-            self.pool = None
+    def _close_pool(self):
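+        # Drop the pool references held by the ptemcee sampler and the stored
+        # result kwargs before the shared base-class pool shutdown.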
+        if getattr(self.sampler, "pool", None) is not None:
+            self.sampler.pool = None
+        if "pool" in self.result.sampler_kwargs:
+            del self.result.sampler_kwargs["pool"]
+        super(Ptemcee, self)._close_pool()
 
+    @signal_wrapper
     def run_sampler(self):
-        self.setup_pool()
+        self._setup_pool()
         sampler = self.setup_sampler()
 
         t0 = datetime.datetime.now()
         logger.info("Starting to sample")
         while True:
             for (pos0, log_posterior, log_likelihood) in sampler.sample(
-                    self.pos0, storechain=False,
-                    iterations=self.convergence_inputs.niterations_per_check,
-                    **self.sampler_function_kwargs):
-                pos0[:, :, self._periodic] = np.mod(
-                    pos0[:, :, self._periodic] - self._minima[self._periodic],
-                    self._range[self._periodic]
-                ) + self._minima[self._periodic]
+                self.pos0,
+                storechain=False,
+                iterations=self.convergence_inputs.niterations_per_check,
+                **self.sampler_function_kwargs,
+            ):
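+                # Fold periodic parameters back into their prior range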
+                pos0[:, :, self._periodic] = (
+                    np.mod(
+                        pos0[:, :, self._periodic] - self._minima[self._periodic],
+                        self._range[self._periodic],
+                    )
+                    + self._minima[self._periodic]
+                )
 
             if self.iteration == self.chain_array.shape[1]:
-                self.chain_array = np.concatenate((
-                    self.chain_array, self.get_zero_chain_array()), axis=1)
-                self.log_likelihood_array = np.concatenate((
-                    self.log_likelihood_array, self.get_zero_array()),
-                    axis=2)
-                self.log_posterior_array = np.concatenate((
-                    self.log_posterior_array, self.get_zero_array()),
-                    axis=2)
+                self.chain_array = np.concatenate(
+                    (self.chain_array, self.get_zero_chain_array()), axis=1
+                )
+                self.log_likelihood_array = np.concatenate(
+                    (self.log_likelihood_array, self.get_zero_array()), axis=2
+                )
+                self.log_posterior_array = np.concatenate(
+                    (self.log_posterior_array, self.get_zero_array()), axis=2
+                )
 
             self.pos0 = pos0
             self.chain_array[:, self.iteration, :] = pos0[0, :, :]
             self.log_likelihood_array[:, :, self.iteration] = log_likelihood
             self.log_posterior_array[:, :, self.iteration] = log_posterior
             self.mean_log_posterior = np.mean(
-                self.log_posterior_array[:, :, :self. iteration], axis=1)
+                self.log_posterior_array[:, :, : self.iteration], axis=1
+            )
 
             # Calculate time per iteration
             self.time_per_check.append((datetime.datetime.now() - t0).total_seconds())
@@ -537,15 +545,13 @@ class Ptemcee(MCMCSampler):
 
             # Calculate minimum iteration step to discard
             minimum_iteration = get_minimum_stable_iteration(
-                self.mean_log_posterior,
-                frac=self.convergence_inputs.mean_logl_frac
+                self.mean_log_posterior, frac=self.convergence_inputs.mean_logl_frac
             )
-            logger.debug("Minimum iteration = {}".format(minimum_iteration))
+            logger.debug(f"Minimum iteration = {minimum_iteration}")
 
             # Calculate the maximum discard number
             discard_max = np.max(
-                [self.convergence_inputs.burn_in_fixed_discard,
-                 minimum_iteration]
+                [self.convergence_inputs.burn_in_fixed_discard, minimum_iteration]
             )
 
             if self.iteration > discard_max + self.nwalkers:
@@ -565,7 +571,7 @@ class Ptemcee(MCMCSampler):
                 self.nsamples_effective,
             ) = check_iteration(
                 self.iteration,
-                self.chain_array[:, self.discard:self.iteration, :],
+                self.chain_array[:, self.discard : self.iteration, :],
                 sampler,
                 self.convergence_inputs,
                 self.search_parameter_keys,
@@ -588,7 +594,7 @@ class Ptemcee(MCMCSampler):
             else:
                 last_checkpoint_s = np.sum(self.time_per_check)
 
-            if last_checkpoint_s > self.check_point_deltaT:
+            if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state(plot=self.check_point_plot)
 
         # Run a final checkpoint to update the plots and samples
@@ -609,9 +615,14 @@ class Ptemcee(MCMCSampler):
         self.result.discard = self.discard
 
         log_evidence, log_evidence_err = compute_evidence(
-            sampler, self.log_likelihood_array, self.outdir,
-            self.label, self.discard, self.nburn,
-            self.thin, self.iteration,
+            sampler,
+            self.log_likelihood_array,
+            self.outdir,
+            self.label,
+            self.discard,
+            self.nburn,
+            self.thin,
+            self.iteration,
         )
         self.result.log_evidence = log_evidence
         self.result.log_evidence_err = log_evidence_err
@@ -620,21 +631,10 @@ class Ptemcee(MCMCSampler):
             seconds=np.sum(self.time_per_check)
         )
 
-        if self.pool:
-            self.pool.close()
+        self._close_pool()
 
         return self.result
 
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        logger.warning("Run terminated with signal {}".format(signum))
-        if getattr(self, "pool", None) or self.threads == 1:
-            self.write_current_state(plot=False)
-        if getattr(self, "pool", None):
-            logger.info("Closing pool")
-            self.pool.close()
-        logger.info("Exit on signal {}".format(self.exit_code))
-        sys.exit(self.exit_code)
-
     def write_current_state(self, plot=True):
         check_directory_exists_and_if_not_mkdir(self.outdir)
         checkpoint(
@@ -672,7 +672,7 @@ class Ptemcee(MCMCSampler):
                     self.discard,
                 )
             except Exception as e:
-                logger.info("Walkers plot failed with exception {}".format(e))
+                logger.info(f"Walkers plot failed with exception {e}")
 
             try:
                 # Generate the tau plot diagnostic if DEBUG
@@ -687,7 +687,7 @@ class Ptemcee(MCMCSampler):
                         self.convergence_inputs.autocorr_tau,
                     )
             except Exception as e:
-                logger.info("tau plot failed with exception {}".format(e))
+                logger.info(f"tau plot failed with exception {e}")
 
             try:
                 plot_mean_log_posterior(
@@ -696,7 +696,7 @@ class Ptemcee(MCMCSampler):
                     self.label,
                 )
             except Exception as e:
-                logger.info("mean_logl plot failed with exception {}".format(e))
+                logger.info(f"mean_logl plot failed with exception {e}")
 
 
 def get_minimum_stable_iteration(mean_array, frac, nsteps_min=10):
@@ -728,7 +728,7 @@ def check_iteration(
     mean_log_posterior,
     verbose=True,
 ):
-    """ Per-iteration logic to calculate the convergence check
+    """Per-iteration logic to calculate the convergence check
 
     Parameters
     ==========
@@ -780,8 +780,17 @@ def check_iteration(
     if np.isnan(tau) or np.isinf(tau):
         if verbose:
             print_progress(
-                iteration, sampler, time_per_check, np.nan, np.nan,
-                np.nan, np.nan, np.nan, False, convergence_inputs, Q,
+                iteration,
+                sampler,
+                time_per_check,
+                np.nan,
+                np.nan,
+                np.nan,
+                np.nan,
+                np.nan,
+                False,
+                convergence_inputs,
+                Q,
             )
         return False, np.nan, np.nan, np.nan, np.nan
 
@@ -796,45 +805,47 @@ def check_iteration(
 
     # Calculate convergence boolean
     converged = Q < ci.Q_tol and ci.nsamples < nsamples_effective
-    logger.debug("Convergence: Q<Q_tol={}, nsamples<nsamples_effective={}"
-                 .format(Q < ci.Q_tol, ci.nsamples < nsamples_effective))
+    logger.debug(
+        f"Convergence: Q<Q_tol={Q < ci.Q_tol}, "
+        f"nsamples<nsamples_effective={ci.nsamples < nsamples_effective}"
+    )
 
     GRAD_WINDOW_LENGTH = nwalkers + 1
     nsteps_to_check = ci.autocorr_tau * np.max([2 * GRAD_WINDOW_LENGTH, tau_int])
     lower_tau_index = np.max([0, len(tau_list) - nsteps_to_check])
-    check_taus = np.array(tau_list[lower_tau_index :])
+    check_taus = np.array(tau_list[lower_tau_index:])
     if not np.any(np.isnan(check_taus)) and check_taus.shape[0] > GRAD_WINDOW_LENGTH:
-        gradient_tau = get_max_gradient(
-            check_taus, axis=0, window_length=11)
+        gradient_tau = get_max_gradient(check_taus, axis=0, window_length=11)
 
         if gradient_tau < ci.gradient_tau:
             logger.debug(
-                "tau usable as {} < gradient_tau={}"
-                .format(gradient_tau, ci.gradient_tau)
+                f"tau usable as {gradient_tau} < gradient_tau={ci.gradient_tau}"
             )
             tau_usable = True
         else:
             logger.debug(
-                "tau not usable as {} > gradient_tau={}"
-                .format(gradient_tau, ci.gradient_tau)
+                f"tau not usable as {gradient_tau} > gradient_tau={ci.gradient_tau}"
             )
             tau_usable = False
 
         check_mean_log_posterior = mean_log_posterior[:, -nsteps_to_check:]
         gradient_mean_log_posterior = get_max_gradient(
-            check_mean_log_posterior, axis=1, window_length=GRAD_WINDOW_LENGTH,
-            smooth=True)
+            check_mean_log_posterior,
+            axis=1,
+            window_length=GRAD_WINDOW_LENGTH,
+            smooth=True,
+        )
 
         if gradient_mean_log_posterior < ci.gradient_mean_log_posterior:
             logger.debug(
-                "tau usable as {} < gradient_mean_log_posterior={}"
-                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+                f"tau usable as {gradient_mean_log_posterior} < "
+                f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}"
             )
             tau_usable *= True
         else:
             logger.debug(
-                "tau not usable as {} > gradient_mean_log_posterior={}"
-                .format(gradient_mean_log_posterior, ci.gradient_mean_log_posterior)
+                f"tau not usable as {gradient_mean_log_posterior} > "
+                f"gradient_mean_log_posterior={ci.gradient_mean_log_posterior}"
             )
             tau_usable = False
 
@@ -864,7 +875,7 @@ def check_iteration(
             gradient_mean_log_posterior,
             tau_usable,
             convergence_inputs,
-            Q
+            Q,
         )
     stop = converged and tau_usable
     return stop, nburn, thin, tau_int, nsamples_effective
@@ -872,13 +883,14 @@ def check_iteration(
 
 def get_max_gradient(x, axis=0, window_length=11, polyorder=2, smooth=False):
     from scipy.signal import savgol_filter
+
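+    # Estimate the maximum first derivative of x with a Savitzky-Golay
+    # filter (deriv=1), optionally smoothing the input first.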
     if smooth:
-        x = savgol_filter(
-            x, axis=axis, window_length=window_length, polyorder=3
+        x = savgol_filter(x, axis=axis, window_length=window_length, polyorder=3)
+    return np.max(
+        savgol_filter(
+            x, axis=axis, window_length=window_length, polyorder=polyorder, deriv=1
         )
-    return np.max(savgol_filter(
-        x, axis=axis, window_length=window_length, polyorder=polyorder,
-        deriv=1))
+    )
 
 
 def get_Q_convergence(samples):
@@ -887,7 +899,7 @@ def get_Q_convergence(samples):
         W = np.mean(np.var(samples, axis=1), axis=0)
         per_walker_mean = np.mean(samples, axis=1)
         mean = np.mean(per_walker_mean, axis=0)
-        B = nsteps / (nwalkers - 1.) * np.sum((per_walker_mean - mean)**2, axis=0)
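+        # Gelman-Rubin statistic: combine the between-walker variance B with
+        # the within-walker variance W into the pooled estimate Vhat; report
+        # the worst dimension, Q = max(sqrt(Vhat / W)).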
+        B = nsteps / (nwalkers - 1.0) * np.sum((per_walker_mean - mean) ** 2, axis=0)
         Vhat = (nsteps - 1) / nsteps * W + (nwalkers + 1) / (nwalkers * nsteps) * B
         Q_per_dim = np.sqrt(Vhat / W)
         return np.max(Q_per_dim)
@@ -910,16 +922,18 @@ def print_progress(
 ):
     # Setup acceptance string
     acceptance = sampler.acceptance_fraction[0, :]
-    acceptance_str = "{:1.2f}-{:1.2f}".format(np.min(acceptance), np.max(acceptance))
+    acceptance_str = f"{np.min(acceptance):1.2f}-{np.max(acceptance):1.2f}"
 
     # Setup tswap acceptance string
     tswap_acceptance_fraction = sampler.tswap_acceptance_fraction
-    tswap_acceptance_str = "{:1.2f}-{:1.2f}".format(
-        np.min(tswap_acceptance_fraction), np.max(tswap_acceptance_fraction)
-    )
+    tswap_acceptance_str = f"{np.min(tswap_acceptance_fraction):1.2f}-{np.max(tswap_acceptance_fraction):1.2f}"
 
     ave_time_per_check = np.mean(time_per_check[-3:])
-    time_left = (convergence_inputs.nsamples - nsamples_effective) * ave_time_per_check / samples_per_check
+    time_left = (
+        (convergence_inputs.nsamples - nsamples_effective)
+        * ave_time_per_check
+        / samples_per_check
+    )
     if time_left > 0:
         time_left = str(datetime.timedelta(seconds=int(time_left)))
     else:
@@ -927,46 +941,44 @@ def print_progress(
 
     sampling_time = datetime.timedelta(seconds=np.sum(time_per_check))
 
-    tau_str = "{}(+{:0.2f},+{:0.2f})".format(
-        tau_int, gradient_tau, gradient_mean_log_posterior
-    )
+    tau_str = f"{tau_int}(+{gradient_tau:0.2f},+{gradient_mean_log_posterior:0.2f})"
 
     if tau_usable:
-        tau_str = "={}".format(tau_str)
+        tau_str = f"={tau_str}"
     else:
-        tau_str = "!{}".format(tau_str)
+        tau_str = f"!{tau_str}"
 
-    Q_str = "{:0.2f}".format(Q)
+    Q_str = f"{Q:0.2f}"
 
-    evals_per_check = sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
+    evals_per_check = (
+        sampler.nwalkers * sampler.ntemps * convergence_inputs.niterations_per_check
+    )
 
-    ncalls = "{:1.1e}".format(
-        convergence_inputs.niterations_per_check * iteration * sampler.nwalkers * sampler.ntemps)
-    eval_timing = "{:1.2f}ms/ev".format(1e3 * ave_time_per_check / evals_per_check)
+    approximate_ncalls = (
+        convergence_inputs.niterations_per_check
+        * iteration
+        * sampler.nwalkers
+        * sampler.ntemps
+    )
+    ncalls = f"{approximate_ncalls:1.1e}"
+    eval_timing = f"{1000.0 * ave_time_per_check / evals_per_check:1.2f}ms/ev"
 
     try:
         print(
-            "{}|{}|nc:{}|a0:{}|swp:{}|n:{}<{}|t{}|q:{}|{}".format(
-                iteration,
-                str(sampling_time).split(".")[0],
-                ncalls,
-                acceptance_str,
-                tswap_acceptance_str,
-                nsamples_effective,
-                convergence_inputs.nsamples,
-                tau_str,
-                Q_str,
-                eval_timing,
-            ),
+            f"{iteration}|{str(sampling_time).split('.')[0]}|nc:{ncalls}|"
+            f"a0:{acceptance_str}|swp:{tswap_acceptance_str}|"
+            f"n:{nsamples_effective}<{convergence_inputs.nsamples}|t{tau_str}|"
+            f"q:{Q_str}|{eval_timing}",
             flush=True,
         )
     except OSError as e:
-        logger.debug("Failed to print iteration due to :{}".format(e))
+        logger.debug(f"Failed to print iteration due to :{e}")
 
 
 def calculate_tau_array(samples, search_parameter_keys, ci):
-    """ Compute ACT tau for 0-temperature chains """
+    """Compute ACT tau for 0-temperature chains"""
     import emcee
+
     nwalkers, nsteps, ndim = samples.shape
     tau_array = np.zeros((nwalkers, ndim)) + np.inf
     if nsteps > 1:
@@ -976,7 +988,8 @@ def calculate_tau_array(samples, search_parameter_keys, ci):
                     continue
                 try:
                     tau_array[ii, jj] = emcee.autocorr.integrated_time(
-                        samples[ii, :, jj], c=ci.autocorr_c, tol=0)[0]
+                        samples[ii, :, jj], c=ci.autocorr_c, tol=0
+                    )[0]
                 except emcee.autocorr.AutocorrError:
                     tau_array[ii, jj] = np.inf
     return tau_array
@@ -1004,21 +1017,24 @@ def checkpoint(
     time_per_check,
 ):
     import dill
+
     logger.info("Writing checkpoint and diagnostics")
     ndim = sampler.dim
 
     # Store the samples if possible
     if nsamples_effective > 0:
-        filename = "{}/{}_samples.txt".format(outdir, label)
-        samples = np.array(chain_array)[:, discard + nburn : iteration : thin, :].reshape(
-            (-1, ndim)
-        )
+        filename = f"{outdir}/{label}_samples.txt"
+        samples = np.array(chain_array)[
+            :, discard + nburn : iteration : thin, :
+        ].reshape((-1, ndim))
         df = pd.DataFrame(samples, columns=search_parameter_keys)
         df.to_csv(filename, index=False, header=True, sep=" ")
 
     # Pickle the resume artefacts
-    sampler_copy = copy.copy(sampler)
-    del sampler_copy.pool
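+    # Temporarily detach the multiprocessing pool (which cannot be pickled)
+    # so the sampler can be deep-copied, then restore it.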
+    pool = sampler.pool
+    sampler.pool = None
+    sampler_copy = copy.deepcopy(sampler)
+    sampler.pool = pool
 
     data = dict(
         iteration=iteration,
@@ -1040,10 +1056,10 @@ def checkpoint(
     logger.info("Finished writing checkpoint")
 
 
-def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
-                 discard=0):
-    """ Method to plot the trace of the walkers in an ensemble MCMC plot """
+def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label, discard=0):
+    """Method to plot the trace of the walkers in an ensemble MCMC plot"""
     import matplotlib.pyplot as plt
+
     nwalkers, nsteps, ndim = walkers.shape
     if np.isnan(nburn):
         nburn = nsteps
@@ -1051,51 +1067,65 @@ def plot_walkers(walkers, nburn, thin, parameter_labels, outdir, label,
         thin = 1
     idxs = np.arange(nsteps)
     fig, axes = plt.subplots(nrows=ndim, ncols=2, figsize=(8, 3 * ndim))
-    scatter_kwargs = dict(lw=0, marker="o", markersize=1, alpha=0.1,)
+    scatter_kwargs = dict(
+        lw=0,
+        marker="o",
+        markersize=1,
+        alpha=0.1,
+    )
 
     # Plot the fixed burn-in
     if discard > 0:
         for i, (ax, axh) in enumerate(axes):
             ax.plot(
-                idxs[: discard],
-                walkers[:, : discard, i].T,
+                idxs[:discard],
+                walkers[:, :discard, i].T,
                 color="gray",
-                **scatter_kwargs
+                **scatter_kwargs,
             )
 
     # Plot the burn-in
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[discard: discard + nburn + 1],
-            walkers[:, discard: discard + nburn + 1, i].T,
+            idxs[discard : discard + nburn + 1],
+            walkers[:, discard : discard + nburn + 1, i].T,
             color="C1",
-            **scatter_kwargs
+            **scatter_kwargs,
         )
 
     # Plot the thinned posterior samples
     for i, (ax, axh) in enumerate(axes):
         ax.plot(
-            idxs[discard + nburn::thin],
-            walkers[:, discard + nburn::thin, i].T,
+            idxs[discard + nburn :: thin],
+            walkers[:, discard + nburn :: thin, i].T,
             color="C0",
-            **scatter_kwargs
+            **scatter_kwargs,
+        )
+        axh.hist(
+            walkers[:, discard + nburn :: thin, i].reshape((-1)), bins=50, alpha=0.8
         )
-        axh.hist(walkers[:, discard + nburn::thin, i].reshape((-1)), bins=50, alpha=0.8)
 
     for i, (ax, axh) in enumerate(axes):
         axh.set_xlabel(parameter_labels[i])
         ax.set_ylabel(parameter_labels[i])
 
     fig.tight_layout()
-    filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
+    filename = f"{outdir}/{label}_checkpoint_trace.png"
     fig.savefig(filename)
     plt.close(fig)
 
 
 def plot_tau(
-    tau_list_n, tau_list, search_parameter_keys, outdir, label, tau, autocorr_tau,
+    tau_list_n,
+    tau_list,
+    search_parameter_keys,
+    outdir,
+    label,
+    tau,
+    autocorr_tau,
 ):
     import matplotlib.pyplot as plt
+
     fig, ax = plt.subplots()
     for i, key in enumerate(search_parameter_keys):
         ax.plot(tau_list_n, np.array(tau_list)[:, i], label=key)
@@ -1103,7 +1133,7 @@ def plot_tau(
     ax.set_ylabel(r"$\langle \tau \rangle$")
     ax.legend()
     fig.tight_layout()
-    fig.savefig("{}/{}_checkpoint_tau.png".format(outdir, label))
+    fig.savefig(f"{outdir}/{label}_checkpoint_tau.png")
     plt.close(fig)
 
 
@@ -1119,17 +1149,30 @@ def plot_mean_log_posterior(mean_log_posterior, outdir, label):
     fig, ax = plt.subplots()
     idxs = np.arange(nsteps)
     ax.plot(idxs, mean_log_posterior.T)
-    ax.set(xlabel="Iteration", ylabel=r"$\langle\mathrm{log-posterior}\rangle$",
-           ylim=(ymin, ymax))
+    ax.set(
+        xlabel="Iteration",
+        ylabel=r"$\langle\mathrm{log-posterior}\rangle$",
+        ylim=(ymin, ymax),
+    )
     fig.tight_layout()
-    fig.savefig("{}/{}_checkpoint_meanlogposterior.png".format(outdir, label))
+    fig.savefig(f"{outdir}/{label}_checkpoint_meanlogposterior.png")
     plt.close(fig)
 
 
-def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nburn, thin,
-                     iteration, make_plots=True):
-    """ Computes the evidence using thermodynamic integration """
+def compute_evidence(
+    sampler,
+    log_likelihood_array,
+    outdir,
+    label,
+    discard,
+    nburn,
+    thin,
+    iteration,
+    make_plots=True,
+):
+    """Computes the evidence using thermodynamic integration"""
     import matplotlib.pyplot as plt
+
     betas = sampler.betas
     # We compute the evidence without the burnin samples, but we do not thin
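+    # Thermodynamic integration: ln Z = int_0^1 <ln L>_beta d beta over the
+    # inverse-temperature ladder betas.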
     lnlike = log_likelihood_array[:, :, discard + nburn : iteration]
@@ -1141,7 +1184,7 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur
     if any(np.isinf(mean_lnlikes)):
         logger.warning(
             "mean_lnlikes contains inf: recalculating without"
-            " the {} infs".format(len(betas[np.isinf(mean_lnlikes)]))
+            f" the {len(betas[np.isinf(mean_lnlikes)])} infs"
         )
         idxs = np.isinf(mean_lnlikes)
         mean_lnlikes = mean_lnlikes[~idxs]
@@ -1165,33 +1208,23 @@ def compute_evidence(sampler, log_likelihood_array, outdir, label, discard, nbur
 
         ax2.semilogx(min_betas, evidence, "-o")
         ax2.set_ylabel(
-            r"$\int_{\beta_{min}}^{\beta=1}" + r"\langle \log(\mathcal{L})\rangle d\beta$",
+            r"$\int_{\beta_{min}}^{\beta=1}"
+            + r"\langle \log(\mathcal{L})\rangle d\beta$",
             size=16,
         )
         ax2.set_xlabel(r"$\beta_{min}$")
         plt.tight_layout()
-        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
+        fig.savefig(f"{outdir}/{label}_beta_lnl.png")
         plt.close(fig)
 
     return lnZ, lnZerr
 
 
 def do_nothing_function():
-    """ This is a do-nothing function, we overwrite the likelihood and prior elsewhere """
+    """This is a do-nothing function, we overwrite the likelihood and prior elsewhere"""
     pass
 
 
-likelihood = None
-priors = None
-
-
-def init(likelihood_in, priors_in):
-    global likelihood
-    global priors
-    likelihood = likelihood_in
-    priors = priors_in
-
-
 class LikePriorEvaluator(object):
     """
     This class is copied and modified from ptemcee.LikePriorEvaluator, see
@@ -1203,38 +1236,43 @@ class LikePriorEvaluator(object):
 
     """
 
-    def __init__(self, search_parameter_keys, use_ratio=False):
-        self.search_parameter_keys = search_parameter_keys
-        self.use_ratio = use_ratio
+    def __init__(self):
         self.periodic_set = False
 
     def _setup_periodic(self):
+        priors = _sampling_convenience_dump.priors
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
         self._periodic = [
-            priors[key].boundary == "periodic" for key in self.search_parameter_keys
+            priors[key].boundary == "periodic" for key in search_parameter_keys
         ]
         priors.sample()
-        self._minima = np.array([
-            priors[key].minimum for key in self.search_parameter_keys
-        ])
-        self._range = np.array([
-            priors[key].maximum for key in self.search_parameter_keys
-        ]) - self._minima
+        self._minima = np.array([priors[key].minimum for key in search_parameter_keys])
+        self._range = (
+            np.array([priors[key].maximum for key in search_parameter_keys])
+            - self._minima
+        )
         self.periodic_set = True
 
     def _wrap_periodic(self, array):
         if not self.periodic_set:
             self._setup_periodic()
-        array[self._periodic] = np.mod(
-            array[self._periodic] - self._minima[self._periodic],
-            self._range[self._periodic]
-        ) + self._minima[self._periodic]
+        array[self._periodic] = (
+            np.mod(
+                array[self._periodic] - self._minima[self._periodic],
+                self._range[self._periodic],
+            )
+            + self._minima[self._periodic]
+        )
         return array
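
The modular arithmetic above folds each periodic coordinate back into [minimum, minimum + range); a standalone sketch of the same operation with made-up values:

```python
import numpy as np

periodic = np.array([True, False])   # e.g. a phase and a non-periodic parameter
minima = np.array([0.0, -1.0])
ranges = np.array([2 * np.pi, 2.0])

point = np.array([7.0, 0.3])         # the phase has drifted outside [0, 2*pi)
point[periodic] = (
    np.mod(point[periodic] - minima[periodic], ranges[periodic])
    + minima[periodic]
)
print(point)  # phase wrapped to 7.0 - 2*pi ~ 0.717; second entry untouched
```
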
 
     def logl(self, v_array):
-        parameters = {key: v for key, v in zip(self.search_parameter_keys, v_array)}
+        priors = _sampling_convenience_dump.priors
+        likelihood = _sampling_convenience_dump.likelihood
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
+        parameters = {key: v for key, v in zip(search_parameter_keys, v_array)}
         if priors.evaluate_constraints(parameters) > 0:
             likelihood.parameters.update(parameters)
-            if self.use_ratio:
+            if _sampling_convenience_dump.use_ratio:
                 return likelihood.log_likelihood() - likelihood.noise_log_likelihood()
             else:
                 return likelihood.log_likelihood()
@@ -1242,9 +1280,15 @@ class LikePriorEvaluator(object):
             return np.nan_to_num(-np.inf)
 
     def logp(self, v_array):
-        params = {key: t for key, t in zip(self.search_parameter_keys, v_array)}
+        priors = _sampling_convenience_dump.priors
+        search_parameter_keys = _sampling_convenience_dump.search_parameter_keys
+        params = {key: t for key, t in zip(search_parameter_keys, v_array)}
         return priors.ln_prob(params)
 
+    def call_emcee(self, theta):
+        ll, lp = self.__call__(theta)
+        return ll + lp, [ll, lp]
+
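
`call_emcee` follows emcee's convention that `log_prob_fn` may return extra per-sample quantities ("blobs") after the log-probability; here the blob stores the log-likelihood and log-prior so both can be recovered per sample later. A hedged, self-contained sketch of that convention with a toy distribution:

```python
import emcee
import numpy as np

def log_prob(theta):
    ll = -0.5 * np.sum(theta**2)   # stand-in log-likelihood
    lp = 0.0                       # stand-in log-prior
    return ll + lp, [ll, lp]       # second item is stored as a "blob"

sampler = emcee.EnsembleSampler(nwalkers=8, ndim=2, log_prob_fn=log_prob)
sampler.run_mcmc(np.random.randn(8, 2), 100)
blobs = sampler.get_blobs()        # shape (nsteps, nwalkers, 2): [ll, lp]
```
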
     def __call__(self, x):
         lp = self.logp(x)
         if np.isnan(lp):
diff --git a/bilby/core/sampler/ptmcmc.py b/bilby/core/sampler/ptmcmc.py
index 49b86d7392ff4d21af698fe7b7034b25eaa7ac22..6b9c3c96eb83a81486df1e38dd5948b668cfb358 100644
--- a/bilby/core/sampler/ptmcmc.py
+++ b/bilby/core/sampler/ptmcmc.py
@@ -1,11 +1,10 @@
-
 import glob
 import shutil
 
 import numpy as np
 
-from .base_sampler import MCMCSampler, SamplerNotInstalledError
 from ..utils import logger
+from .base_sampler import MCMCSampler, SamplerNotInstalledError, signal_wrapper
 
 
 class PTMCMCSampler(MCMCSampler):
@@ -42,29 +41,66 @@ class PTMCMCSampler(MCMCSampler):
 
     """
 
-    default_kwargs = {'p0': None, 'Niter': 2 * 10 ** 4 + 1, 'neff': 10 ** 4,
-                      'burn': 5 * 10 ** 3, 'verbose': True,
-                      'ladder': None, 'Tmin': 1, 'Tmax': None, 'Tskip': 100,
-                      'isave': 1000, 'thin': 1, 'covUpdate': 1000,
-                      'SCAMweight': 1, 'AMweight': 1, 'DEweight': 1,
-                      'HMCweight': 0, 'MALAweight': 0, 'NUTSweight': 0,
-                      'HMCstepsize': 0.1, 'HMCsteps': 300,
-                      'groups': None, 'custom_proposals': None,
-                      'loglargs': {}, 'loglkwargs': {}, 'logpargs': {},
-                      'logpkwargs': {}, 'logl_grad': None, 'logp_grad': None,
-                      'outDir': None}
-
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False, skip_import_verification=False,
-                 pos0=None, burn_in_fraction=0.25, **kwargs):
-
-        super(PTMCMCSampler, self).__init__(likelihood=likelihood, priors=priors,
-                                            outdir=outdir, label=label, use_ratio=use_ratio,
-                                            plot=plot,
-                                            skip_import_verification=skip_import_verification,
-                                            **kwargs)
-
-        self.p0 = self.get_random_draw_from_prior()
+    default_kwargs = {
+        "p0": None,
+        "Niter": 2 * 10**4 + 1,
+        "neff": 10**4,
+        "burn": 5 * 10**3,
+        "verbose": True,
+        "ladder": None,
+        "Tmin": 1,
+        "Tmax": None,
+        "Tskip": 100,
+        "isave": 1000,
+        "thin": 1,
+        "covUpdate": 1000,
+        "SCAMweight": 1,
+        "AMweight": 1,
+        "DEweight": 1,
+        "HMCweight": 0,
+        "MALAweight": 0,
+        "NUTSweight": 0,
+        "HMCstepsize": 0.1,
+        "HMCsteps": 300,
+        "groups": None,
+        "custom_proposals": None,
+        "loglargs": {},
+        "loglkwargs": {},
+        "logpargs": {},
+        "logpkwargs": {},
+        "logl_grad": None,
+        "logp_grad": None,
+        "outDir": None,
+    }
+    hard_exit = True
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        **kwargs,
+    ):
+
+        super(PTMCMCSampler, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+
+        if self.kwargs["p0"] is None:
+            self.p0 = self.get_random_draw_from_prior()
+        else:
+            self.p0 = self.kwargs["p0"]
         self.likelihood = likelihood
         self.priors = priors
 
@@ -73,88 +109,102 @@ class PTMCMCSampler(MCMCSampler):
         # which forces `__name__.lower()`
         external_sampler_name = self.__class__.__name__
         try:
-            self.external_sampler = __import__(external_sampler_name)
+            __import__(external_sampler_name)
         except (ImportError, SystemExit):
             raise SamplerNotInstalledError(
-                "Sampler {} is not installed on this system".format(external_sampler_name))
+                f"Sampler {external_sampler_name} is not installed on this system"
+            )
 
     def _translate_kwargs(self, kwargs):
-        if 'Niter' not in kwargs:
+        if "Niter" not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['Niter'] = kwargs.pop(equiv)
-        if 'burn' not in kwargs:
+                    kwargs["Niter"] = kwargs.pop(equiv)
+        if "burn" not in kwargs:
             for equiv in self.nburn_equiv_kwargs:
                 if equiv in kwargs:
-                    kwargs['burn'] = kwargs.pop(equiv)
+                    kwargs["burn"] = kwargs.pop(equiv)
 
     @property
     def custom_proposals(self):
-        return self.kwargs['custom_proposals']
+        return self.kwargs["custom_proposals"]
 
     @property
     def sampler_init_kwargs(self):
-        keys = ['groups',
-                'loglargs',
-                'logp_grad',
-                'logpkwargs',
-                'loglkwargs',
-                'logl_grad',
-                'logpargs',
-                'outDir',
-                'verbose']
+        keys = [
+            "groups",
+            "loglargs",
+            "logp_grad",
+            "logpkwargs",
+            "loglkwargs",
+            "logl_grad",
+            "logpargs",
+            "outDir",
+            "verbose",
+        ]
         init_kwargs = {key: self.kwargs[key] for key in keys}
-        if init_kwargs['outDir'] is None:
-            init_kwargs['outDir'] = '{}/ptmcmc_temp_{}/'.format(self.outdir, self.label)
+        if init_kwargs["outDir"] is None:
+            init_kwargs["outDir"] = f"{self.outdir}/ptmcmc_temp_{self.label}/"
         return init_kwargs
 
     @property
     def sampler_function_kwargs(self):
-        keys = ['Niter',
-                'neff',
-                'Tmin',
-                'HMCweight',
-                'covUpdate',
-                'SCAMweight',
-                'ladder',
-                'burn',
-                'NUTSweight',
-                'AMweight',
-                'MALAweight',
-                'thin',
-                'HMCstepsize',
-                'isave',
-                'Tskip',
-                'HMCsteps',
-                'Tmax',
-                'DEweight']
+        keys = [
+            "Niter",
+            "neff",
+            "Tmin",
+            "HMCweight",
+            "covUpdate",
+            "SCAMweight",
+            "ladder",
+            "burn",
+            "NUTSweight",
+            "AMweight",
+            "MALAweight",
+            "thin",
+            "HMCstepsize",
+            "isave",
+            "Tskip",
+            "HMCsteps",
+            "Tmax",
+            "DEweight",
+        ]
         sampler_kwargs = {key: self.kwargs[key] for key in keys}
         return sampler_kwargs
 
     @staticmethod
     def _import_external_sampler():
         from PTMCMCSampler import PTMCMCSampler
+
         return PTMCMCSampler
 
+    @signal_wrapper
     def run_sampler(self):
         PTMCMCSampler = self._import_external_sampler()
-        sampler = PTMCMCSampler.PTSampler(ndim=self.ndim, logp=self.log_prior,
-                                          logl=self.log_likelihood, cov=np.eye(self.ndim),
-                                          **self.sampler_init_kwargs)
+        sampler = PTMCMCSampler.PTSampler(
+            ndim=self.ndim,
+            logp=self.log_prior,
+            logl=self.log_likelihood,
+            cov=np.eye(self.ndim),
+            **self.sampler_init_kwargs,
+        )
         if self.custom_proposals is not None:
             for proposal in self.custom_proposals:
-                logger.info('Adding {} to proposals with weight {}'.format(
-                    proposal, self.custom_proposals[proposal][1]))
-                sampler.addProposalToCycle(self.custom_proposals[proposal][0],
-                                           self.custom_proposals[proposal][1])
+                logger.info(
+                    f"Adding {proposal} to proposals with weight {self.custom_proposals[proposal][1]}"
+                )
+                sampler.addProposalToCycle(
+                    self.custom_proposals[proposal][0],
+                    self.custom_proposals[proposal][1],
+                )
         sampler.sample(p0=self.p0, **self.sampler_function_kwargs)
         samples, meta, loglike = self.__read_in_data()
 
         self.calc_likelihood_count()
-        self.result.nburn = self.sampler_function_kwargs['burn']
-        self.result.samples = samples[self.result.nburn:]
-        self.meta_data['sampler_meta'] = meta
-        self.result.log_likelihood_evaluations = loglike[self.result.nburn:]
+        self.result.nburn = self.sampler_function_kwargs["burn"]
+        self.result.samples = samples[self.result.nburn :]
+        self.meta_data["sampler_meta"] = meta
+        self.result.log_likelihood_evaluations = loglike[self.result.nburn :]
         self.result.sampler_output = np.nan
         self.result.walkers = np.nan
         self.result.log_evidence = np.nan
@@ -162,30 +212,34 @@ class PTMCMCSampler(MCMCSampler):
         return self.result
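
Judging from the loop in `run_sampler`, `custom_proposals` maps a proposal name to a `(function, weight)` pair handed to `addProposalToCycle`. A hedged usage sketch; the jump function and weight below are invented, using the `(x, iter, beta) -> (q, log_qxy)` signature that PTMCMCSampler's own proposals follow:

```python
import numpy as np

def small_gaussian_jump(x, it, beta):
    # Propose a nearby point; return it with the log proposal ratio
    # (0.0 here because the jump is symmetric).
    q = x + 1e-2 * np.random.randn(len(x))
    return q, 0.0

# bilby.run_sampler(
#     likelihood, priors, sampler="ptmcmcsampler",
#     custom_proposals={"small_gauss": (small_gaussian_jump, 10)},
# )
```
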
 
     def __read_in_data(self):
-        """ Read the data stored by PTMCMC to disk """
-        temp_outDir = self.sampler_init_kwargs['outDir']
+        """Read the data stored by PTMCMC to disk"""
+        temp_outDir = self.sampler_init_kwargs["outDir"]
         try:
-            data = np.loadtxt('{}chain_1.txt'.format(temp_outDir))
+            data = np.loadtxt(f"{temp_outDir}chain_1.txt")
         except OSError:
-            data = np.loadtxt('{}chain_1.0.txt'.format(temp_outDir))
-        jumpfiles = glob.glob('{}/*jump.txt'.format(temp_outDir))
+            data = np.loadtxt(f"{temp_outDir}chain_1.0.txt")
+        jumpfiles = glob.glob(f"{temp_outDir}/*jump.txt")
         jumps = map(np.loadtxt, jumpfiles)
         samples = data[:, :-4]
         loglike = data[:, -3]
 
         jump_accept = {}
         for ct, j in enumerate(jumps):
-            label = jumpfiles[ct].split('/')[-1].split('_jump.txt')[0]
+            label = jumpfiles[ct].split("/")[-1].split("_jump.txt")[0]
             jump_accept[label] = j
-        PT_swap = {'swap_accept': data[:, -1]}
-        tot_accept = {'tot_accept': data[:, -2]}
-        log_post = {'log_post': data[:, -4]}
+        PT_swap = {"swap_accept": data[:, -1]}
+        tot_accept = {"tot_accept": data[:, -2]}
+        log_post = {"log_post": data[:, -4]}
         meta = {}
-        meta['tot_accept'] = tot_accept
-        meta['PT_swap'] = PT_swap
-        meta['proposals'] = jump_accept
-        meta['log_post'] = log_post
+        meta["tot_accept"] = tot_accept
+        meta["PT_swap"] = PT_swap
+        meta["proposals"] = jump_accept
+        meta["log_post"] = log_post
 
         shutil.rmtree(temp_outDir)
 
         return samples, meta, loglike
+
+    def write_current_state(self):
+        """TODO: implement a checkpointing method"""
+        pass
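
Since `write_current_state` is a stub, the `hard_exit = True` flag set earlier makes the signal handler terminate immediately rather than attempt a checkpoint. Purely as an illustration of the shape a future implementation might take (the resume-file layout here is invented, not bilby's):

```python
import dill

def write_current_state(self):
    # Hypothetical sketch only: persist enough state to restart the run,
    # mirroring the resume files other bilby samplers write.
    resume_file = f"{self.outdir}/{self.label}_resume.pickle"
    with open(resume_file, "wb") as file:
        dill.dump({"p0": self.p0, "kwargs": self.kwargs}, file)
```
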
diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py
index 4ff4b232adcfb3432a1f8ce2973fb24fc4778d7b..1c6a2790a74c61001577bc8606ee0834a1b80455 100644
--- a/bilby/core/sampler/pymc3.py
+++ b/bilby/core/sampler/pymc3.py
@@ -2,16 +2,20 @@ from distutils.version import StrictVersion
 
 import numpy as np
 
+from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
+from ..likelihood import (
+    ExponentialLikelihood,
+    GaussianLikelihood,
+    PoissonLikelihood,
+    StudentTLikelihood,
+)
+from ..prior import Cosine, DeltaFunction, MultivariateGaussian, PowerLaw, Sine
 from ..utils import derivatives, infer_args_from_method
-from ..prior import DeltaFunction, Sine, Cosine, PowerLaw, MultivariateGaussian
 from .base_sampler import MCMCSampler
-from ..likelihood import GaussianLikelihood, PoissonLikelihood, ExponentialLikelihood, \
-    StudentTLikelihood
-from ...gw.likelihood import BasicGravitationalWaveTransient, GravitationalWaveTransient
 
 
 class Pymc3(MCMCSampler):
-    """ bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
+    """bilby wrapper of the PyMC3 sampler (https://docs.pymc.io/)
 
     All keyword arguments (i.e., the kwargs) passed to `run_sampler` will be
     propagated to `pymc3.sample` where appropriate, see documentation for that
@@ -51,39 +55,77 @@ class Pymc3(MCMCSampler):
     """
 
     default_kwargs = dict(
-        draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0,
-        chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None,
-        discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None,
+        draws=500,
+        step=None,
+        init="auto",
+        n_init=200000,
+        start=None,
+        trace=None,
+        chain_idx=0,
+        chains=2,
+        cores=1,
+        tune=500,
+        progressbar=True,
+        model=None,
+        random_seed=None,
+        discard_tuned_samples=True,
+        compute_convergence_checks=True,
+        nuts_kwargs=None,
         step_kwargs=None,
     )
 
     default_nuts_kwargs = dict(
-        target_accept=None, max_treedepth=None, step_scale=None, Emax=None,
-        gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None,
-        scaling=None, is_cov=None, potential=None,
+        target_accept=None,
+        max_treedepth=None,
+        step_scale=None,
+        Emax=None,
+        gamma=None,
+        k=None,
+        t0=None,
+        adapt_step_size=None,
+        early_max_treedepth=None,
+        scaling=None,
+        is_cov=None,
+        potential=None,
     )
 
     default_kwargs.update(default_nuts_kwargs)
 
-    def __init__(self, likelihood, priors, outdir='outdir', label='label',
-                 use_ratio=False, plot=False,
-                 skip_import_verification=False, **kwargs):
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        **kwargs,
+    ):
         # add default step kwargs
         _, STEP_METHODS, _ = self._import_external_sampler()
         self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
         self.default_kwargs.update(self.default_step_kwargs)
 
-        super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                                    use_ratio=use_ratio, plot=plot,
-                                    skip_import_verification=skip_import_verification, **kwargs)
-        self.draws = self._kwargs['draws']
-        self.chains = self._kwargs['chains']
+        super(Pymc3, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs,
+        )
+        self.draws = self._kwargs["draws"]
+        self.chains = self._kwargs["chains"]
 
     @staticmethod
     def _import_external_sampler():
         import pymc3
         from pymc3.sampling import STEP_METHODS
         from pymc3.theanof import floatX
+
         return pymc3, STEP_METHODS, floatX
 
     @staticmethod
@@ -91,6 +133,7 @@ class Pymc3(MCMCSampler):
         import theano  # noqa
         import theano.tensor as tt
         from theano.compile.ops import as_op  # noqa
+
         return theano, tt, as_op
 
     def _verify_parameters(self):
@@ -116,77 +159,79 @@ class Pymc3(MCMCSampler):
         self.prior_map = prior_map
 
         # predefined PyMC3 distributions
-        prior_map['Gaussian'] = {
-            'pymc3': 'Normal',
-            'argmap': {'mu': 'mu', 'sigma': 'sd'}}
-        prior_map['TruncatedGaussian'] = {
-            'pymc3': 'TruncatedNormal',
-            'argmap': {'mu': 'mu',
-                       'sigma': 'sd',
-                       'minimum': 'lower',
-                       'maximum': 'upper'}}
-        prior_map['HalfGaussian'] = {
-            'pymc3': 'HalfNormal',
-            'argmap': {'sigma': 'sd'}}
-        prior_map['Uniform'] = {
-            'pymc3': 'Uniform',
-            'argmap': {'minimum': 'lower',
-                       'maximum': 'upper'}}
-        prior_map['LogNormal'] = {
-            'pymc3': 'Lognormal',
-            'argmap': {'mu': 'mu',
-                       'sigma': 'sd'}}
-        prior_map['Exponential'] = {
-            'pymc3': 'Exponential',
-            'argmap': {'mu': 'lam'},
-            'argtransform': {'mu': lambda mu: 1. / mu}}
-        prior_map['StudentT'] = {
-            'pymc3': 'StudentT',
-            'argmap': {'df': 'nu',
-                       'mu': 'mu',
-                       'scale': 'sd'}}
-        prior_map['Beta'] = {
-            'pymc3': 'Beta',
-            'argmap': {'alpha': 'alpha',
-                       'beta': 'beta'}}
-        prior_map['Logistic'] = {
-            'pymc3': 'Logistic',
-            'argmap': {'mu': 'mu',
-                       'scale': 's'}}
-        prior_map['Cauchy'] = {
-            'pymc3': 'Cauchy',
-            'argmap': {'alpha': 'alpha',
-                       'beta': 'beta'}}
-        prior_map['Gamma'] = {
-            'pymc3': 'Gamma',
-            'argmap': {'k': 'alpha',
-                       'theta': 'beta'},
-            'argtransform': {'theta': lambda theta: 1. / theta}}
-        prior_map['ChiSquared'] = {
-            'pymc3': 'ChiSquared',
-            'argmap': {'nu': 'nu'}}
-        prior_map['Interped'] = {
-            'pymc3': 'Interpolated',
-            'argmap': {'xx': 'x_points',
-                       'yy': 'pdf_points'}}
-        prior_map['Normal'] = prior_map['Gaussian']
-        prior_map['TruncatedNormal'] = prior_map['TruncatedGaussian']
-        prior_map['HalfNormal'] = prior_map['HalfGaussian']
-        prior_map['LogGaussian'] = prior_map['LogNormal']
-        prior_map['Lorentzian'] = prior_map['Cauchy']
-        prior_map['FromFile'] = prior_map['Interped']
+        prior_map["Gaussian"] = {
+            "pymc3": "Normal",
+            "argmap": {"mu": "mu", "sigma": "sd"},
+        }
+        prior_map["TruncatedGaussian"] = {
+            "pymc3": "TruncatedNormal",
+            "argmap": {
+                "mu": "mu",
+                "sigma": "sd",
+                "minimum": "lower",
+                "maximum": "upper",
+            },
+        }
+        prior_map["HalfGaussian"] = {"pymc3": "HalfNormal", "argmap": {"sigma": "sd"}}
+        prior_map["Uniform"] = {
+            "pymc3": "Uniform",
+            "argmap": {"minimum": "lower", "maximum": "upper"},
+        }
+        prior_map["LogNormal"] = {
+            "pymc3": "Lognormal",
+            "argmap": {"mu": "mu", "sigma": "sd"},
+        }
+        prior_map["Exponential"] = {
+            "pymc3": "Exponential",
+            "argmap": {"mu": "lam"},
+            "argtransform": {"mu": lambda mu: 1.0 / mu},
+        }
+        prior_map["StudentT"] = {
+            "pymc3": "StudentT",
+            "argmap": {"df": "nu", "mu": "mu", "scale": "sd"},
+        }
+        prior_map["Beta"] = {
+            "pymc3": "Beta",
+            "argmap": {"alpha": "alpha", "beta": "beta"},
+        }
+        prior_map["Logistic"] = {
+            "pymc3": "Logistic",
+            "argmap": {"mu": "mu", "scale": "s"},
+        }
+        prior_map["Cauchy"] = {
+            "pymc3": "Cauchy",
+            "argmap": {"alpha": "alpha", "beta": "beta"},
+        }
+        prior_map["Gamma"] = {
+            "pymc3": "Gamma",
+            "argmap": {"k": "alpha", "theta": "beta"},
+            "argtransform": {"theta": lambda theta: 1.0 / theta},
+        }
+        prior_map["ChiSquared"] = {"pymc3": "ChiSquared", "argmap": {"nu": "nu"}}
+        prior_map["Interped"] = {
+            "pymc3": "Interpolated",
+            "argmap": {"xx": "x_points", "yy": "pdf_points"},
+        }
+        prior_map["Normal"] = prior_map["Gaussian"]
+        prior_map["TruncatedNormal"] = prior_map["TruncatedGaussian"]
+        prior_map["HalfNormal"] = prior_map["HalfGaussian"]
+        prior_map["LogGaussian"] = prior_map["LogNormal"]
+        prior_map["Lorentzian"] = prior_map["Cauchy"]
+        prior_map["FromFile"] = prior_map["Interped"]
 
         # GW specific priors
-        prior_map['UniformComovingVolume'] = prior_map['Interped']
+        prior_map["UniformComovingVolume"] = prior_map["Interped"]
 
         # internally defined mappings for bilby priors
-        prior_map['DeltaFunction'] = {'internal': self._deltafunction_prior}
-        prior_map['Sine'] = {'internal': self._sine_prior}
-        prior_map['Cosine'] = {'internal': self._cosine_prior}
-        prior_map['PowerLaw'] = {'internal': self._powerlaw_prior}
-        prior_map['LogUniform'] = {'internal': self._powerlaw_prior}
-        prior_map['MultivariateGaussian'] = {'internal': self._multivariate_normal_prior}
-        prior_map['MultivariateNormal'] = {'internal': self._multivariate_normal_prior}
+        prior_map["DeltaFunction"] = {"internal": self._deltafunction_prior}
+        prior_map["Sine"] = {"internal": self._sine_prior}
+        prior_map["Cosine"] = {"internal": self._cosine_prior}
+        prior_map["PowerLaw"] = {"internal": self._powerlaw_prior}
+        prior_map["LogUniform"] = {"internal": self._powerlaw_prior}
+        prior_map["MultivariateGaussian"] = {
+            "internal": self._multivariate_normal_prior
+        }
+        prior_map["MultivariateNormal"] = {"internal": self._multivariate_normal_prior}
 
     def _deltafunction_prior(self, key, **kwargs):
         """
@@ -197,7 +242,7 @@ class Pymc3(MCMCSampler):
         if isinstance(self.priors[key], DeltaFunction):
             return self.priors[key].peak
         else:
-            raise ValueError("Prior for '{}' is not a DeltaFunction".format(key))
+            raise ValueError(f"Prior for '{key}' is not a DeltaFunction")
 
     def _sine_prior(self, key):
         """
@@ -210,20 +255,22 @@ class Pymc3(MCMCSampler):
         if isinstance(self.priors[key], Sine):
 
             class Pymc3Sine(pymc3.Continuous):
-                def __init__(self, lower=0., upper=np.pi):
+                def __init__(self, lower=0.0, upper=np.pi):
                     if lower >= upper:
                         raise ValueError("Lower bound is above upper bound!")
 
                     # set the mode
                     self.lower = lower = tt.as_tensor_variable(floatX(lower))
                     self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                    self.norm = (tt.cos(lower) - tt.cos(upper))
-                    self.mean = \
-                        (tt.sin(upper) + lower * tt.cos(lower) -
-                         tt.sin(lower) - upper * tt.cos(upper)) / self.norm
+                    self.norm = tt.cos(lower) - tt.cos(upper)
+                    self.mean = (
+                        tt.sin(upper)
+                        + lower * tt.cos(lower)
+                        - tt.sin(lower)
+                        - upper * tt.cos(upper)
+                    ) / self.norm
 
-                    transform = pymc3.distributions.transforms.interval(lower,
-                                                                        upper)
+                    transform = pymc3.distributions.transforms.interval(lower, upper)
 
                     super(Pymc3Sine, self).__init__(transform=transform)
 
@@ -232,12 +279,15 @@ class Pymc3(MCMCSampler):
                     lower = self.lower
                     return pymc3.distributions.dist_math.bound(
                         tt.log(tt.sin(value) / self.norm),
-                        lower <= value, value <= upper)
+                        lower <= value,
+                        value <= upper,
+                    )
 
-            return Pymc3Sine(key, lower=self.priors[key].minimum,
-                             upper=self.priors[key].maximum)
+            return Pymc3Sine(
+                key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
+            )
         else:
-            raise ValueError("Prior for '{}' is not a Sine".format(key))
+            raise ValueError(f"Prior for '{key}' is not a Sine")
 
     def _cosine_prior(self, key):
         """
@@ -250,19 +300,21 @@ class Pymc3(MCMCSampler):
         if isinstance(self.priors[key], Cosine):
 
             class Pymc3Cosine(pymc3.Continuous):
-                def __init__(self, lower=-np.pi / 2., upper=np.pi / 2.):
+                def __init__(self, lower=-np.pi / 2.0, upper=np.pi / 2.0):
                     if lower >= upper:
                         raise ValueError("Lower bound is above upper bound!")
 
                     self.lower = lower = tt.as_tensor_variable(floatX(lower))
                     self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                    self.norm = (tt.sin(upper) - tt.sin(lower))
-                    self.mean = \
-                        (upper * tt.sin(upper) + tt.cos(upper) -
-                         lower * tt.sin(lower) - tt.cos(lower)) / self.norm
+                    self.norm = tt.sin(upper) - tt.sin(lower)
+                    self.mean = (
+                        upper * tt.sin(upper)
+                        + tt.cos(upper)
+                        - lower * tt.sin(lower)
+                        - tt.cos(lower)
+                    ) / self.norm
 
-                    transform = pymc3.distributions.transforms.interval(lower,
-                                                                        upper)
+                    transform = pymc3.distributions.transforms.interval(lower, upper)
 
                     super(Pymc3Cosine, self).__init__(transform=transform)
 
@@ -271,12 +323,15 @@ class Pymc3(MCMCSampler):
                     lower = self.lower
                     return pymc3.distributions.dist_math.bound(
                         tt.log(tt.cos(value) / self.norm),
-                        lower <= value, value <= upper)
+                        lower <= value,
+                        value <= upper,
+                    )
 
-            return Pymc3Cosine(key, lower=self.priors[key].minimum,
-                               upper=self.priors[key].maximum)
+            return Pymc3Cosine(
+                key, lower=self.priors[key].minimum, upper=self.priors[key].maximum
+            )
         else:
-            raise ValueError("Prior for '{}' is not a Cosine".format(key))
+            raise ValueError(f"Prior for '{key}' is not a Cosine")
 
     def _powerlaw_prior(self, key):
         """
@@ -289,17 +344,18 @@ class Pymc3(MCMCSampler):
         if isinstance(self.priors[key], PowerLaw):
 
             # check power law is set
-            if not hasattr(self.priors[key], 'alpha'):
+            if not hasattr(self.priors[key], "alpha"):
                 raise AttributeError("No 'alpha' attribute set for PowerLaw prior")
 
-            if self.priors[key].alpha < -1.:
+            if self.priors[key].alpha < -1.0:
                 # use Pareto distribution
-                palpha = -(1. + self.priors[key].alpha)
+                palpha = -(1.0 + self.priors[key].alpha)
 
-                return pymc3.Bound(
-                    pymc3.Pareto, upper=self.priors[key].minimum)(
-                    key, alpha=palpha, m=self.priors[key].maximum)
+                return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)(
+                    key, alpha=palpha, m=self.priors[key].maximum
+                )
             else:
+
                 class Pymc3PowerLaw(pymc3.Continuous):
                     def __init__(self, lower, upper, alpha, testval=1):
                         falpha = alpha
@@ -308,17 +364,21 @@ class Pymc3(MCMCSampler):
                         self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
 
                         if falpha == -1:
-                            self.norm = 1. / (tt.log(self.upper / self.lower))
+                            self.norm = 1.0 / (tt.log(self.upper / self.lower))
                         else:
-                            beta = (1. + self.alpha)
-                            self.norm = 1. / (beta * (tt.pow(self.upper, beta) -
-                                                      tt.pow(self.lower, beta)))
+                            beta = 1.0 + self.alpha
+                            self.norm = 1.0 / (
+                                beta
+                                * (tt.pow(self.upper, beta) - tt.pow(self.lower, beta))
+                            )
 
                         transform = pymc3.distributions.transforms.interval(
-                            lower, upper)
+                            lower, upper
+                        )
 
                         super(Pymc3PowerLaw, self).__init__(
-                            transform=transform, testval=testval)
+                            transform=transform, testval=testval
+                        )
 
                     def logp(self, value):
                         upper = self.upper
@@ -327,13 +387,18 @@ class Pymc3(MCMCSampler):
 
                         return pymc3.distributions.dist_math.bound(
                             alpha * tt.log(value) + tt.log(self.norm),
-                            lower <= value, value <= upper)
-
-                return Pymc3PowerLaw(key, lower=self.priors[key].minimum,
-                                     upper=self.priors[key].maximum,
-                                     alpha=self.priors[key].alpha)
+                            lower <= value,
+                            value <= upper,
+                        )
+
+                return Pymc3PowerLaw(
+                    key,
+                    lower=self.priors[key].minimum,
+                    upper=self.priors[key].maximum,
+                    alpha=self.priors[key].alpha,
+                )
         else:
-            raise ValueError("Prior for '{}' is not a Power Law".format(key))
+            raise ValueError(f"Prior for '{key}' is not a Power Law")
 
     def _multivariate_normal_prior(self, key):
         """
@@ -359,14 +424,14 @@ class Pymc3(MCMCSampler):
                 testvals = []
                 for bound in mvg.bounds.values():
                     if np.isinf(bound[0]) and np.isinf(bound[1]):
-                        testvals.append(0.)
+                        testvals.append(0.0)
                     elif np.isinf(bound[0]):
-                        testvals.append(bound[1] - 1.)
+                        testvals.append(bound[1] - 1.0)
                     elif np.isinf(bound[1]):
-                        testvals.append(bound[0] + 1.)
+                        testvals.append(bound[0] + 1.0)
                     else:
                         # half-way between the two bounds
-                        testvals.append(bound[0] + (bound[1] - bound[0]) / 2.)
+                        testvals.append(bound[0] + (bound[1] - bound[0]) / 2.0)
 
                 # if bounds are at +/-infinity set to 100 sigmas as infinities
                 # cause problems for the Bound class
@@ -375,44 +440,54 @@ class Pymc3(MCMCSampler):
                 maxsigma = np.max(mvg.sigmas, axis=0)
                 for i in range(len(mvpars)):
                     if np.isinf(lower[i]):
-                        lower[i] = minmu[i] - 100. * maxsigma[i]
+                        lower[i] = minmu[i] - 100.0 * maxsigma[i]
                     if np.isinf(upper[i]):
-                        upper[i] = maxmu[i] + 100. * maxsigma[i]
+                        upper[i] = maxmu[i] + 100.0 * maxsigma[i]
 
                 # create a bounded MultivariateNormal distribution
                 BoundedMvN = pymc3.Bound(pymc3.MvNormal, lower=lower, upper=upper)
 
                 comp_dists = []  # list of any component modes
                 for i in range(mvg.nmodes):
-                    comp_dists.append(BoundedMvN('comp{}'.format(i), mu=mvg.mus[i],
-                                                 cov=mvg.covs[i],
-                                                 shape=len(mvpars)).distribution)
+                    comp_dists.append(
+                        BoundedMvN(
+                            f"comp{i}",
+                            mu=mvg.mus[i],
+                            cov=mvg.covs[i],
+                            shape=len(mvpars),
+                        ).distribution
+                    )
 
                 # create a Mixture model
-                setname = 'mixture{}'.format(self.multivariate_normal_num_sets)
-                mix = pymc3.Mixture(setname, w=mvg.weights, comp_dists=comp_dists,
-                                    shape=len(mvpars), testval=testvals)
+                setname = f"mixture{self.multivariate_normal_num_sets}"
+                mix = pymc3.Mixture(
+                    setname,
+                    w=mvg.weights,
+                    comp_dists=comp_dists,
+                    shape=len(mvpars),
+                    testval=testvals,
+                )
 
                 for i, p in enumerate(mvpars):
                     self.multivariate_normal_sets[p] = {}
-                    self.multivariate_normal_sets[p]['prior'] = mix[i]
-                    self.multivariate_normal_sets[p]['set'] = setname
-                    self.multivariate_normal_sets[p]['index'] = i
+                    self.multivariate_normal_sets[p]["prior"] = mix[i]
+                    self.multivariate_normal_sets[p]["set"] = setname
+                    self.multivariate_normal_sets[p]["index"] = i
 
                 self.multivariate_normal_num_sets += 1
 
             # return required parameter
-            return self.multivariate_normal_sets[key]['prior']
+            return self.multivariate_normal_sets[key]["prior"]
 
         else:
-            raise ValueError("Prior for '{}' is not a MultivariateGaussian".format(key))
+            raise ValueError(f"Prior for '{key}' is not a MultivariateGaussian")
 
     def run_sampler(self):
         # set the step method
         pymc3, STEP_METHODS, floatX = self._import_external_sampler()
         step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
-        if 'step' in self._kwargs:
-            self.step_method = self._kwargs.pop('step')
+        if "step" in self._kwargs:
+            self.step_method = self._kwargs.pop("step")
 
             # 'step' could be a dictionary of methods for different parameters,
             # so check for this
@@ -421,7 +496,9 @@ class Pymc3(MCMCSampler):
             elif isinstance(self.step_method, dict):
                 for key in self.step_method:
                     if key not in self._search_parameter_keys:
-                        raise ValueError("Setting a step method for an unknown parameter '{}'".format(key))
+                        raise ValueError(
+                            f"Setting a step method for an unknown parameter '{key}'"
+                        )
                     else:
                         # check if using a compound step (a list of step
                         # methods for a particular parameter)
@@ -431,7 +508,9 @@ class Pymc3(MCMCSampler):
                             sms = [self.step_method[key]]
                         for sm in sms:
                             if sm.lower() not in step_methods:
-                                raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
+                                raise ValueError(
+                                    f"Using invalid step method '{self.step_method[key]}'"
+                                )
             else:
                 # check if using a compound step (a list of step
                 # methods for a particular parameter)
@@ -442,7 +521,7 @@ class Pymc3(MCMCSampler):
 
                 for i in range(len(sms)):
                     if sms[i].lower() not in step_methods:
-                        raise ValueError("Using invalid step method '{}'".format(sms[i]))
+                        raise ValueError(f"Using invalid step method '{sms[i]}'")
         else:
             self.step_method = None
 
@@ -457,7 +536,7 @@ class Pymc3(MCMCSampler):
         # takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
         # the likelihood within that context manager
         likeargs = infer_args_from_method(self.likelihood.log_likelihood)
-        if 'sampler' in likeargs:
+        if "sampler" in likeargs:
             self.likelihood.log_likelihood(sampler=self)
         else:
             # set the likelihood function from predefined functions
@@ -498,7 +577,7 @@ class Pymc3(MCMCSampler):
         # set the step method
         if isinstance(self.step_method, dict):
             # create list of step methods (any not given will default to NUTS)
-            self.kwargs['step'] = []
+            self.kwargs["step"] = []
             with self.pymc3_model:
                 for key in self.step_method:
                     # check for a compound step list
@@ -506,13 +585,25 @@ class Pymc3(MCMCSampler):
                         for sms in self.step_method[key]:
                             curmethod = sms.lower()
                             methodslist.append(curmethod)
-                            nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
-                                                                   step_methods)
+                            nuts_kwargs = self._create_nuts_kwargs(
+                                curmethod,
+                                key,
+                                nuts_kwargs,
+                                pymc3,
+                                step_kwargs,
+                                step_methods,
+                            )
                     else:
                         curmethod = self.step_method[key].lower()
                         methodslist.append(curmethod)
-                        nuts_kwargs = self._create_nuts_kwargs(curmethod, key, nuts_kwargs, pymc3, step_kwargs,
-                                                               step_methods)
+                        nuts_kwargs = self._create_nuts_kwargs(
+                            curmethod,
+                            key,
+                            nuts_kwargs,
+                            pymc3,
+                            step_kwargs,
+                            step_methods,
+                        )
         else:
             with self.pymc3_model:
                 # check for a compound step list
@@ -521,28 +612,38 @@ class Pymc3(MCMCSampler):
                     for sms in self.step_method:
                         curmethod = sms.lower()
                         methodslist.append(curmethod)
-                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(
+                            curmethod, nuts_kwargs, step_kwargs
+                        )
                         compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
-                        self.kwargs['step'] = compound
+                        self.kwargs["step"] = compound
                 else:
-                    self.kwargs['step'] = None
+                    self.kwargs["step"] = None
                     if self.step_method is not None:
                         curmethod = self.step_method.lower()
                         methodslist.append(curmethod)
-                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(curmethod, nuts_kwargs, step_kwargs)
-                        self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
+                        args, nuts_kwargs = self._create_args_and_nuts_kwargs(
+                            curmethod, nuts_kwargs, step_kwargs
+                        )
+                        self.kwargs["step"] = pymc3.__dict__[step_methods[curmethod]](
+                            **args
+                        )
                     else:
                         # re-add step_kwargs if no step methods are set
-                        if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
-                            self.kwargs['step_kwargs'] = step_kwargs
+                        if len(step_kwargs) > 0 and StrictVersion(
+                            pymc3.__version__
+                        ) < StrictVersion("3.7"):
+                            self.kwargs["step_kwargs"] = step_kwargs
 
         # check whether only NUTS step method has been assigned
-        if np.all([sm.lower() == 'nuts' for sm in methodslist]):
+        if np.all([sm.lower() == "nuts" for sm in methodslist]):
             # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
-            self.kwargs['step'] = None
+            self.kwargs["step"] = None
 
-            if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
-                self.kwargs['nuts_kwargs'] = nuts_kwargs
+            if len(nuts_kwargs) > 0 and StrictVersion(
+                pymc3.__version__
+            ) < StrictVersion("3.7"):
+                self.kwargs["nuts_kwargs"] = nuts_kwargs
             elif len(nuts_kwargs) > 0:
                 # add NUTS kwargs to standard kwargs
                 self.kwargs.update(nuts_kwargs)
@@ -564,22 +665,27 @@ class Pymc3(MCMCSampler):
         return self.result
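
Putting the step-method parsing together: `step` may be a single name, a list (a compound step), or a per-parameter dictionary whose values are names or lists; anything unrecognised raises the errors above. A hedged end-to-end sketch with an invented linear model:

```python
import bilby
import numpy as np

# Hypothetical linear model y = m*x + c with known sigma.
x = np.linspace(0, 1, 50)
y = 2 * x + 1 + np.random.normal(0, 0.1, 50)
likelihood = bilby.core.likelihood.GaussianLikelihood(
    x, y, lambda x, m, c: m * x + c, sigma=0.1
)
priors = dict(
    m=bilby.core.prior.Uniform(0, 5, "m"),
    c=bilby.core.prior.Uniform(0, 5, "c"),
)

result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="pymc3",
    draws=500,
    step={"m": "slice", "c": ["nuts", "metropolis"]},  # per-parameter / compound
)
```
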
 
     def _create_args_and_nuts_kwargs(self, curmethod, nuts_kwargs, step_kwargs):
-        if curmethod == 'nuts':
+        if curmethod == "nuts":
             args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
         else:
             args = step_kwargs.get(curmethod, {})
         return args, nuts_kwargs
 
-    def _create_nuts_kwargs(self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods):
-        if curmethod == 'nuts':
+    def _create_nuts_kwargs(
+        self, curmethod, key, nuts_kwargs, pymc3, step_kwargs, step_methods
+    ):
+        if curmethod == "nuts":
             args, nuts_kwargs = self._get_nuts_args(nuts_kwargs, step_kwargs)
         else:
             if step_kwargs is not None:
                 args = step_kwargs.get(curmethod, {})
             else:
                 args = {}
-        self.kwargs['step'].append(
-            pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
+        self.kwargs["step"].append(
+            pymc3.__dict__[step_methods[curmethod]](
+                vars=[self.pymc3_priors[key]], **args
+            )
+        )
         return nuts_kwargs
 
     @staticmethod
@@ -587,7 +693,7 @@ class Pymc3(MCMCSampler):
         if nuts_kwargs is not None:
             args = nuts_kwargs
         elif step_kwargs is not None:
-            args = step_kwargs.pop('nuts', {})
+            args = step_kwargs.pop("nuts", {})
             # add values into nuts_kwargs
             nuts_kwargs = args
         else:
@@ -618,55 +724,85 @@ class Pymc3(MCMCSampler):
                 # if the prior contains an ln_prob method that takes a 'sampler' argument
                 # then try using that
                 lnprobargs = infer_args_from_method(self.priors[key].ln_prob)
-                if 'sampler' in lnprobargs:
+                if "sampler" in lnprobargs:
                     try:
                         self.pymc3_priors[key] = self.priors[key].ln_prob(sampler=self)
                     except RuntimeError:
-                        raise RuntimeError(("Problem setting PyMC3 prior for ",
-                                            "'{}'".format(key)))
+                        raise RuntimeError(
+                            f"Problem setting PyMC3 prior for '{key}'"
+                        )
                 else:
                     # use Prior distribution name
                     distname = self.priors[key].__class__.__name__
 
                     if distname in self.prior_map:
                         # check if we have a predefined PyMC3 distribution
-                        if 'pymc3' in self.prior_map[distname] and 'argmap' in self.prior_map[distname]:
+                        if (
+                            "pymc3" in self.prior_map[distname]
+                            and "argmap" in self.prior_map[distname]
+                        ):
                             # check the required arguments for the PyMC3 distribution
-                            pymc3distname = self.prior_map[distname]['pymc3']
+                            pymc3distname = self.prior_map[distname]["pymc3"]
 
                             if pymc3distname not in pymc3.__dict__:
-                                raise ValueError("Prior '{}' is not a known PyMC3 distribution.".format(pymc3distname))
+                                raise ValueError(
+                                    f"Prior '{pymc3distname}' is not a known PyMC3 distribution."
+                                )
 
-                            reqargs = infer_args_from_method(pymc3.__dict__[pymc3distname].__init__)
+                            reqargs = infer_args_from_method(
+                                pymc3.__dict__[pymc3distname].__init__
+                            )
 
                             # set keyword arguments
                             priorkwargs = {}
-                            for (targ, parg) in self.prior_map[distname]['argmap'].items():
+                            for (targ, parg) in self.prior_map[distname][
+                                "argmap"
+                            ].items():
                                 if hasattr(self.priors[key], targ):
                                     if parg in reqargs:
-                                        if 'argtransform' in self.prior_map[distname]:
-                                            if targ in self.prior_map[distname]['argtransform']:
-                                                tfunc = self.prior_map[distname]['argtransform'][targ]
+                                        if "argtransform" in self.prior_map[distname]:
+                                            if (
+                                                targ
+                                                in self.prior_map[distname][
+                                                    "argtransform"
+                                                ]
+                                            ):
+                                                tfunc = self.prior_map[distname][
+                                                    "argtransform"
+                                                ][targ]
                                             else:
+
                                                 def tfunc(x):
                                                     return x
+
                                         else:
+
                                             def tfunc(x):
                                                 return x
 
-                                        priorkwargs[parg] = tfunc(getattr(self.priors[key], targ))
+                                        priorkwargs[parg] = tfunc(
+                                            getattr(self.priors[key], targ)
+                                        )
                                     else:
-                                        raise ValueError("Unknown argument {}".format(parg))
+                                        raise ValueError(f"Unknown argument {parg}")
                                 else:
                                     if parg in reqargs:
                                         priorkwargs[parg] = None
-                            self.pymc3_priors[key] = pymc3.__dict__[pymc3distname](key, **priorkwargs)
-                        elif 'internal' in self.prior_map[distname]:
-                            self.pymc3_priors[key] = self.prior_map[distname]['internal'](key)
+                            self.pymc3_priors[key] = pymc3.__dict__[pymc3distname](
+                                key, **priorkwargs
+                            )
+                        elif "internal" in self.prior_map[distname]:
+                            self.pymc3_priors[key] = self.prior_map[distname][
+                                "internal"
+                            ](key)
                         else:
-                            raise ValueError("Prior '{}' is not a known distribution.".format(distname))
+                            raise ValueError(
+                                f"Prior '{distname}' is not a known distribution."
+                            )
                     else:
-                        raise ValueError("Prior '{}' is not a known distribution.".format(distname))
+                        raise ValueError(
+                            f"Prior '{distname}' is not a known distribution."
+                        )
 
     def set_likelihood(self):
         """
@@ -692,17 +828,19 @@ class Pymc3(MCMCSampler):
                     if isinstance(self.priors[key], float):
                         self.likelihood.parameters[key] = self.priors[key]
 
-                self.logpgrad = LogLikeGrad(self.parameters, self.likelihood, self.priors)
+                self.logpgrad = LogLikeGrad(
+                    self.parameters, self.likelihood, self.priors
+                )
 
             def perform(self, node, inputs, outputs):
-                theta, = inputs
+                (theta,) = inputs
                 for i, key in enumerate(self.parameters):
                     self.likelihood.parameters[key] = theta[i]
 
                 outputs[0][0] = np.array(self.likelihood.log_likelihood())
 
             def grad(self, inputs, g):
-                theta, = inputs
+                (theta,) = inputs
                 return [g[0] * self.logpgrad(theta)]
 
         # create theano Op for calculating the gradient of the log likelihood
@@ -723,7 +861,7 @@ class Pymc3(MCMCSampler):
                         self.likelihood.parameters[key] = self.priors[key]
 
             def perform(self, node, inputs, outputs):
-                theta, = inputs
+                (theta,) = inputs
 
                 # define version of likelihood function to pass to derivative function
                 def lnlike(values):
@@ -732,7 +870,9 @@ class Pymc3(MCMCSampler):
                     return self.likelihood.log_likelihood()
 
                 # calculate gradients
-                grads = derivatives(theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2)
+                grads = derivatives(
+                    theta, lnlike, abseps=1e-5, mineps=1e-12, reltol=1e-2
+                )
 
                 outputs[0][0] = grads
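
The `LogLike`/`LogLikeGrad` pair above is the standard trick for exposing a non-Theano likelihood to PyMC3's gradient-based samplers: `perform` evaluates the black box, and `grad` delegates to a second Op that returns numerical derivatives. A stripped-down sketch of the pattern, using plain central differences in place of bilby's `derivatives` helper:

```python
import numpy as np
import theano.tensor as tt

def lnlike(theta):
    return -0.5 * np.sum(theta**2)  # stand-in black-box log-likelihood

class LogLikeGrad(tt.Op):
    itypes = [tt.dvector]
    otypes = [tt.dvector]

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        eps = 1e-6
        grads = np.empty_like(theta)
        for i in range(len(theta)):  # central finite differences
            dx = np.zeros_like(theta)
            dx[i] = eps
            grads[i] = (lnlike(theta + dx) - lnlike(theta - dx)) / (2 * eps)
        outputs[0][0] = grads

class LogLike(tt.Op):
    itypes = [tt.dvector]
    otypes = [tt.dscalar]

    def __init__(self):
        self.logpgrad = LogLikeGrad()

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        outputs[0][0] = np.array(lnlike(theta))

    def grad(self, inputs, g):
        (theta,) = inputs
        return [g[0] * self.logpgrad(theta)]
```
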
 
@@ -740,84 +880,114 @@ class Pymc3(MCMCSampler):
             #  check if it is a predefined likelihood function
             if isinstance(self.likelihood, GaussianLikelihood):
                 # check required attributes exist
-                if (not hasattr(self.likelihood, 'sigma') or
-                        not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Gaussian Likelihood does not have all the correct attributes!")
-
-                if 'sigma' in self.pymc3_priors:
+                if (
+                    not hasattr(self.likelihood, "sigma")
+                    or not hasattr(self.likelihood, "x")
+                    or not hasattr(self.likelihood, "y")
+                ):
+                    raise ValueError(
+                        "Gaussian Likelihood does not have all the correct attributes!"
+                    )
+
+                if "sigma" in self.pymc3_priors:
                     # if sigma is supplied use that value
                     if self.likelihood.sigma is None:
-                        self.likelihood.sigma = self.pymc3_priors.pop('sigma')
+                        self.likelihood.sigma = self.pymc3_priors.pop("sigma")
                     else:
-                        del self.pymc3_priors['sigma']
+                        del self.pymc3_priors["sigma"]
 
                 for key in self.pymc3_priors:
                     if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
 
                 # set the distribution
-                pymc3.Normal('likelihood', mu=model, sd=self.likelihood.sigma,
-                             observed=self.likelihood.y)
+                pymc3.Normal(
+                    "likelihood",
+                    mu=model,
+                    sd=self.likelihood.sigma,
+                    observed=self.likelihood.y,
+                )
             elif isinstance(self.likelihood, PoissonLikelihood):
                 # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Poisson Likelihood does not have all the correct attributes!")
+                if not hasattr(self.likelihood, "x") or not hasattr(
+                    self.likelihood, "y"
+                ):
+                    raise ValueError(
+                        "Poisson Likelihood does not have all the correct attributes!"
+                    )
 
                 for key in self.pymc3_priors:
                     if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 # get rate function
                 model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
 
                 # set the distribution
-                pymc3.Poisson('likelihood', mu=model, observed=self.likelihood.y)
+                pymc3.Poisson("likelihood", mu=model, observed=self.likelihood.y)
             elif isinstance(self.likelihood, ExponentialLikelihood):
                 # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y')):
-                    raise ValueError("Exponential Likelihood does not have all the correct attributes!")
+                if not hasattr(self.likelihood, "x") or not hasattr(
+                    self.likelihood, "y"
+                ):
+                    raise ValueError(
+                        "Exponential Likelihood does not have all the correct attributes!"
+                    )
 
                 for key in self.pymc3_priors:
                     if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 # get mean function
                 model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
 
                 # set the distribution
-                pymc3.Exponential('likelihood', lam=1. / model, observed=self.likelihood.y)
+                pymc3.Exponential(
+                    "likelihood", lam=1.0 / model, observed=self.likelihood.y
+                )
             elif isinstance(self.likelihood, StudentTLikelihood):
                 # check required attributes exist
-                if (not hasattr(self.likelihood, 'x') or
-                        not hasattr(self.likelihood, 'y') or
-                        not hasattr(self.likelihood, 'nu') or
-                        not hasattr(self.likelihood, 'sigma')):
-                    raise ValueError("StudentT Likelihood does not have all the correct attributes!")
-
-                if 'nu' in self.pymc3_priors:
+                if (
+                    not hasattr(self.likelihood, "x")
+                    or not hasattr(self.likelihood, "y")
+                    or not hasattr(self.likelihood, "nu")
+                    or not hasattr(self.likelihood, "sigma")
+                ):
+                    raise ValueError(
+                        "StudentT Likelihood does not have all the correct attributes!"
+                    )
+
+                if "nu" in self.pymc3_priors:
                     # if nu is supplied, use that value
                     if self.likelihood.nu is None:
-                        self.likelihood.nu = self.pymc3_priors.pop('nu')
+                        self.likelihood.nu = self.pymc3_priors.pop("nu")
                     else:
-                        del self.pymc3_priors['nu']
+                        del self.pymc3_priors["nu"]
 
                 for key in self.pymc3_priors:
                     if key not in self.likelihood.function_keys:
-                        raise ValueError("Prior key '{}' is not a function key!".format(key))
+                        raise ValueError(f"Prior key '{key}' is not a function key!")
 
                 model = self.likelihood.func(self.likelihood.x, **self.pymc3_priors)
 
                 # set the distribution
-                pymc3.StudentT('likelihood', nu=self.likelihood.nu, mu=model, sd=self.likelihood.sigma,
-                               observed=self.likelihood.y)
-            elif isinstance(self.likelihood, (GravitationalWaveTransient, BasicGravitationalWaveTransient)):
+                pymc3.StudentT(
+                    "likelihood",
+                    nu=self.likelihood.nu,
+                    mu=model,
+                    sd=self.likelihood.sigma,
+                    observed=self.likelihood.y,
+                )
+            elif isinstance(
+                self.likelihood,
+                (GravitationalWaveTransient, BasicGravitationalWaveTransient),
+            ):
                 # set theano Op - pass _search_parameter_keys, which only contains non-fixed variables
-                logl = LogLike(self._search_parameter_keys, self.likelihood, self.pymc3_priors)
+                logl = LogLike(
+                    self._search_parameter_keys, self.likelihood, self.pymc3_priors
+                )
 
                 parameters = dict()
                 for key in self._search_parameter_keys:
@@ -825,11 +995,14 @@ class Pymc3(MCMCSampler):
                         parameters[key] = self.pymc3_priors[key]
                     except KeyError:
                         raise KeyError(
-                            "Unknown key '{}' when setting GravitationalWaveTransient likelihood".format(key))
+                            f"Unknown key '{key}' when setting GravitationalWaveTransient likelihood"
+                        )
 
                 # convert to theano tensor variable
                 values = tt.as_tensor_variable(list(parameters.values()))
 
-                pymc3.DensityDist('likelihood', lambda v: logl(v), observed={'v': values})
+                pymc3.DensityDist(
+                    "likelihood", lambda v: logl(v), observed={"v": values}
+                )
             else:
                 raise ValueError("Unknown likelihood has been provided")
diff --git a/bilby/core/sampler/pymultinest.py b/bilby/core/sampler/pymultinest.py
index d9869362256ad60c8029acdd0e19020a31dea8f8..da6e7a9778273f0bb25a2702bddcbc2fd18f4ae3 100644
--- a/bilby/core/sampler/pymultinest.py
+++ b/bilby/core/sampler/pymultinest.py
@@ -1,20 +1,15 @@
+import datetime
 import importlib
 import os
-import shutil
-import distutils.dir_util
-import signal
 import time
-import datetime
-import sys
 
 import numpy as np
 
-from ..utils import check_directory_exists_and_if_not_mkdir
 from ..utils import logger
-from .base_sampler import NestedSampler
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
-class Pymultinest(NestedSampler):
+class Pymultinest(_TemporaryFileSamplerMixin, NestedSampler):
     """
     bilby wrapper of pymultinest
     (https://github.com/JohannesBuchner/PyMultiNest)
@@ -65,6 +60,8 @@ class Pymultinest(NestedSampler):
         init_MPI=False,
         dump_callback=None,
     )
+    short_name = "pm"
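+    # exit immediately on checkpoint signals; MultiNest cannot be interrupted cleanly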
+    hard_exit = True
 
     def __init__(
         self,
@@ -94,6 +91,7 @@ class Pymultinest(NestedSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
+            temporary_directory=temporary_directory,
             **kwargs
         )
         self._apply_multinest_boundaries()
@@ -105,10 +103,6 @@ class Pymultinest(NestedSampler):
             )
         self.use_temporary_directory = temporary_directory and not using_mpi
 
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
     def _translate_kwargs(self, kwargs):
         if "n_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
@@ -141,74 +135,7 @@ class Pymultinest(NestedSampler):
                 else:
                     self.kwargs["wrapped_params"].append(0)
 
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = "{}/pm_{}/".format(self.outdir, self.label)
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """Write current state and exit on exit_code"""
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        sys.exit(self.exit_code)
-
-    def _copy_temporary_directory_contents_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path.
-        Do not delete the temporary directory.
-        """
-        logger.info(
-            "Overwriting {} with {}".format(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-        )
-        if self.outputfiles_basename.endswith("/"):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(
-            self.temporary_outputfiles_basename, outputfiles_basename_stripped
-        )
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Copy the temporary back to the proper path
-
-        Anything in the temporary directory at this point is removed
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
-
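+    # signal_wrapper installs the checkpoint-and-exit signal handlers around the run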
+    @signal_wrapper
     def run_sampler(self):
         import pymultinest
 
@@ -247,27 +174,6 @@ class Pymultinest(NestedSampler):
         self.result.nested_samples = self._nested_samples
         return self.result
 
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = self.kwargs["outputfiles_basename"] + "/sampling_time.dat"
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, "r") as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-        self.start_time = current_time
-        with open(self.time_file_path, "w") as time_file:
-            time_file.write(str(self.total_sampling_time))
-
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["outputfiles_basename"] = self.outputfiles_basename
-
     @property
     def _nested_samples(self):
         """
diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py
index 2348319e4b6048eb2669e798e94c6a06af9b6e0e..fc70b38ad4f0b04169254e8031c0a996087abdd2 100644
--- a/bilby/core/sampler/ultranest.py
+++ b/bilby/core/sampler/ultranest.py
@@ -1,20 +1,15 @@
-
 import datetime
-import distutils.dir_util
 import inspect
-import os
-import shutil
-import signal
 import time
 
 import numpy as np
 from pandas import DataFrame
 
-from ..utils import check_directory_exists_and_if_not_mkdir, logger
-from .base_sampler import NestedSampler
+from ..utils import logger
+from .base_sampler import NestedSampler, _TemporaryFileSamplerMixin, signal_wrapper
 
 
-class Ultranest(NestedSampler):
+class Ultranest(_TemporaryFileSamplerMixin, NestedSampler):
     """
     bilby wrapper of ultranest
     (https://johannesbuchner.github.io/UltraNest/index.html)
@@ -73,6 +68,8 @@ class Ultranest(NestedSampler):
         step_sampler=None,
     )
 
+    short_name = "ultra"
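+    # short_name sets the default output directory, <outdir>/ultra_<label>/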
+
     def __init__(
         self,
         likelihood,
@@ -96,31 +93,31 @@ class Ultranest(NestedSampler):
             plot=plot,
             skip_import_verification=skip_import_verification,
             exit_code=exit_code,
+            temporary_directory=temporary_directory,
             **kwargs,
         )
         self._apply_ultranest_boundaries()
-        self.use_temporary_directory = temporary_directory
 
         if self.use_temporary_directory:
             # set callback interval, so copying of results does not thrash the
             # disk (ultranest will call viz_callback quite a lot)
             self.callback_interval = callback_interval
 
-        signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
-        signal.signal(signal.SIGINT, self.write_current_state_and_exit)
-        signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
-
     def _translate_kwargs(self, kwargs):
         if "num_live_points" not in kwargs:
             for equiv in self.npoints_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["num_live_points"] = kwargs.pop(equiv)
-
         if "verbose" in kwargs and "show_status" not in kwargs:
             kwargs["show_status"] = kwargs.pop("verbose")
+        resume = kwargs.get("resume", False)
+        if isinstance(resume, bool):
+            kwargs["resume"] = "overwrite"
 
     def _verify_kwargs_against_default_kwargs(self):
-        """ Check the kwargs """
+        """Check the kwargs"""
 
         self.outputfiles_basename = self.kwargs.pop("log_dir", None)
         if self.kwargs["viz_callback"] is None:
@@ -148,76 +145,13 @@ class Ultranest(NestedSampler):
                     else:
                         self.kwargs["wrapped_params"].append(0)
 
-    @property
-    def outputfiles_basename(self):
-        return self._outputfiles_basename
-
-    @outputfiles_basename.setter
-    def outputfiles_basename(self, outputfiles_basename):
-        if outputfiles_basename is None:
-            outputfiles_basename = os.path.join(
-                self.outdir, "ultra_{}/".format(self.label)
-            )
-        if not outputfiles_basename.endswith("/"):
-            outputfiles_basename += "/"
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        self._outputfiles_basename = outputfiles_basename
-
-    @property
-    def temporary_outputfiles_basename(self):
-        return self._temporary_outputfiles_basename
-
-    @temporary_outputfiles_basename.setter
-    def temporary_outputfiles_basename(self, temporary_outputfiles_basename):
-        if not temporary_outputfiles_basename.endswith("/"):
-            temporary_outputfiles_basename = "{}/".format(
-                temporary_outputfiles_basename
-            )
-        self._temporary_outputfiles_basename = temporary_outputfiles_basename
-        if os.path.exists(self.outputfiles_basename):
-            shutil.copytree(
-                self.outputfiles_basename, self.temporary_outputfiles_basename
-            )
-
-    def write_current_state_and_exit(self, signum=None, frame=None):
-        """ Write current state and exit on exit_code """
-        logger.info(
-            "Run interrupted by signal {}: checkpoint and exit on {}".format(
-                signum, self.exit_code
-            )
-        )
-        self._calculate_and_save_sampling_time()
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-        os._exit(self.exit_code)
-
     def _copy_temporary_directory_contents_to_proper_path(self):
         """
         Copy the temporary directory back to the proper path.
         Do not delete the temporary directory.
         """
         if inspect.stack()[1].function != "_viz_callback":
-            logger.info(
-                "Overwriting {} with {}".format(
-                    self.outputfiles_basename, self.temporary_outputfiles_basename
-                )
-            )
-        if self.outputfiles_basename.endswith("/"):
-            outputfiles_basename_stripped = self.outputfiles_basename[:-1]
-        else:
-            outputfiles_basename_stripped = self.outputfiles_basename
-        distutils.dir_util.copy_tree(
-            self.temporary_outputfiles_basename, outputfiles_basename_stripped
-        )
-
-    def _move_temporary_directory_to_proper_path(self):
-        """
-        Move the temporary back to the proper path
-
-        Anything in the proper path at this point is removed including links
-        """
-        self._copy_temporary_directory_contents_to_proper_path()
-        shutil.rmtree(self.temporary_outputfiles_basename)
+            super(Ultranest, self)._copy_temporary_directory_contents_to_proper_path()
 
     @property
     def sampler_function_kwargs(self):
@@ -271,6 +205,7 @@ class Ultranest(NestedSampler):
 
         return init_kwargs
 
+    @signal_wrapper
     def run_sampler(self):
         import ultranest
         import ultranest.stepsampler
@@ -285,7 +220,7 @@ class Ultranest(NestedSampler):
         stepsampler = self.kwargs.pop("step_sampler", None)
 
         self._setup_run_directory()
-        self.kwargs["log_dir"] = self.kwargs.pop("outputfiles_basename")
+        self.kwargs["log_dir"] = self.kwargs["outputfiles_basename"]
         self._check_and_load_sampling_time_file()
 
         # use reactive nested sampler when no live points are given
@@ -317,7 +252,6 @@ class Ultranest(NestedSampler):
         results = sampler.run(**self.sampler_function_kwargs)
         self._calculate_and_save_sampling_time()
 
-        # Clean up
         self._clean_up_run_directory()
 
         self._generate_result(results)
@@ -325,27 +259,6 @@ class Ultranest(NestedSampler):
 
         return self.result
 
-    def _clean_up_run_directory(self):
-        if self.use_temporary_directory:
-            self._move_temporary_directory_to_proper_path()
-            self.kwargs["log_dir"] = self.outputfiles_basename
-
-    def _check_and_load_sampling_time_file(self):
-        self.time_file_path = os.path.join(self.kwargs["log_dir"], "sampling_time.dat")
-        if os.path.exists(self.time_file_path):
-            with open(self.time_file_path, "r") as time_file:
-                self.total_sampling_time = float(time_file.readline())
-        else:
-            self.total_sampling_time = 0
-
-    def _calculate_and_save_sampling_time(self):
-        current_time = time.time()
-        new_sampling_time = current_time - self.start_time
-        self.total_sampling_time += new_sampling_time
-        with open(self.time_file_path, "w") as time_file:
-            time_file.write(str(self.total_sampling_time))
-        self.start_time = current_time
-
     def _generate_result(self, out):
         # extract results
         data = np.array(out["weighted_samples"]["points"])
@@ -357,16 +270,22 @@ class Ultranest(NestedSampler):
         nested_samples = DataFrame(data, columns=self.search_parameter_keys)
         nested_samples["weights"] = weights
         nested_samples["log_likelihood"] = out["weighted_samples"]["logl"]
-        self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["logl"])[
-            mask
-        ]
+        self.result.log_likelihood_evaluations = np.array(
+            out["weighted_samples"]["logl"]
+        )[mask]
         self.result.sampler_output = out
         self.result.samples = data[mask, :]
         self.result.nested_samples = nested_samples
         self.result.log_evidence = out["logz"]
         self.result.log_evidence_err = out["logzerr"]
         if self.kwargs["num_live_points"] is not None:
-            self.result.information_gain = np.power(out["logzerr"], 2) * self.kwargs["num_live_points"]
+            self.result.information_gain = (
+                np.power(out["logzerr"], 2) * self.kwargs["num_live_points"]
+            )
 
         self.result.outputfiles_basename = self.outputfiles_basename
         self.result.sampling_time = datetime.timedelta(seconds=self.total_sampling_time)
+
+    def log_likelihood(self, theta):
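+        # ultranest does not accept non-finite values, so map NaN/inf to finite numbers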
+        log_l = super(Ultranest, self).log_likelihood(theta=theta)
+        return np.nan_to_num(log_l)
diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py
index 78c3529ea00c1c9ee518a8698e2f70bc29fe194d..c7ae40da222201e5b29c53635c16c3edc94744f0 100644
--- a/bilby/core/sampler/zeus.py
+++ b/bilby/core/sampler/zeus.py
@@ -1,18 +1,17 @@
 import os
-import signal
 import shutil
-import sys
-from collections import namedtuple
 from shutil import copyfile
 
 import numpy as np
-from pandas import DataFrame
 
-from ..utils import logger, check_directory_exists_and_if_not_mkdir
-from .base_sampler import MCMCSampler, SamplerError
+from .base_sampler import SamplerError, signal_wrapper
+from .emcee import Emcee
+from .ptemcee import LikePriorEvaluator
 
+_evaluator = LikePriorEvaluator()
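+# module-level evaluator, used below as the zeus log-probability function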
 
-class Zeus(MCMCSampler):
+
+class Zeus(Emcee):
     """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/)
 
     All positional and keyword arguments (i.e., the args and kwargs) passed to
@@ -65,12 +64,8 @@ class Zeus(MCMCSampler):
         burn_in_fraction=0.25,
         resume=True,
         burn_in_act=3,
-        **kwargs
+        **kwargs,
     ):
-        import zeus
-
-        self.zeus = zeus
-
         super(Zeus, self).__init__(
             likelihood=likelihood,
             priors=priors,
@@ -79,25 +74,16 @@ class Zeus(MCMCSampler):
             use_ratio=use_ratio,
             plot=plot,
             skip_import_verification=skip_import_verification,
-            **kwargs
+            pos0=pos0,
+            nburn=nburn,
+            burn_in_fraction=burn_in_fraction,
+            resume=resume,
+            burn_in_act=burn_in_act,
+            **kwargs,
         )
-        self.resume = resume
-        self.pos0 = pos0
-        self.nburn = nburn
-        self.burn_in_fraction = burn_in_fraction
-        self.burn_in_act = burn_in_act
-
-        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
-        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
 
     def _translate_kwargs(self, kwargs):
-        if "nwalkers" not in kwargs:
-            for equiv in self.nwalkers_equiv_kwargs:
-                if equiv in kwargs:
-                    kwargs["nwalkers"] = kwargs.pop(equiv)
-        if "iterations" not in kwargs:
-            if "nsteps" in kwargs:
-                kwargs["iterations"] = kwargs.pop("nsteps")
+        super(Zeus, self)._translate_kwargs(kwargs=kwargs)
 
         # check if using emcee-style arguments
         if "start" not in kwargs:
@@ -107,17 +93,6 @@ class Zeus(MCMCSampler):
             if "lnprob0" in kwargs:
                 kwargs["log_prob0"] = kwargs.pop("lnprob0")
 
-        if "threads" in kwargs:
-            if kwargs["threads"] != 1:
-                logger.warning(
-                    "The 'threads' argument cannot be used for "
-                    "parallelisation. This run will proceed "
-                    "without parallelisation, but consider the use "
-                    "of an appropriate Pool object passed to the "
-                    "'pool' keyword."
-                )
-                kwargs["threads"] = 1
-
     @property
     def sampler_function_kwargs(self):
         keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"]
@@ -134,168 +109,21 @@ class Zeus(MCMCSampler):
             if key not in self.sampler_function_kwargs
         }
 
-        init_kwargs["logprob_fn"] = self.lnpostfn
+        init_kwargs["logprob_fn"] = _evaluator.call_emcee
         init_kwargs["ndim"] = self.ndim
 
         return init_kwargs
 
-    def lnpostfn(self, theta):
-        log_prior = self.log_prior(theta)
-        if np.isinf(log_prior):
-            return -np.inf, [np.nan, np.nan]
-        else:
-            log_likelihood = self.log_likelihood(theta)
-            return log_likelihood + log_prior, [log_likelihood, log_prior]
-
-    @property
-    def nburn(self):
-        if type(self.__nburn) in [float, int]:
-            return int(self.__nburn)
-        elif self.result.max_autocorrelation_time is None:
-            return int(self.burn_in_fraction * self.nsteps)
-        else:
-            return int(self.burn_in_act * self.result.max_autocorrelation_time)
-
-    @nburn.setter
-    def nburn(self, nburn):
-        if isinstance(nburn, (float, int)):
-            if nburn > self.kwargs["iterations"] - 1:
-                raise ValueError(
-                    "Number of burn-in samples must be smaller "
-                    "than the total number of iterations"
-                )
-
-        self.__nburn = nburn
-
-    @property
-    def nwalkers(self):
-        return self.kwargs["nwalkers"]
-
-    @property
-    def nsteps(self):
-        return self.kwargs["iterations"]
-
-    @nsteps.setter
-    def nsteps(self, nsteps):
-        self.kwargs["iterations"] = nsteps
-
-    @property
-    def stored_chain(self):
-        """Read the stored zero-temperature chain data in from disk"""
-        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
-
-    @property
-    def stored_samples(self):
-        """Returns the samples stored on disk"""
-        return self.stored_chain[self.search_parameter_keys]
-
-    @property
-    def stored_loglike(self):
-        """Returns the log-likelihood stored on disk"""
-        return self.stored_chain["log_l"]
-
-    @property
-    def stored_logprior(self):
-        """Returns the log-prior stored on disk"""
-        return self.stored_chain["log_p"]
-
-    def _init_chain_file(self):
-        with open(self.checkpoint_info.chain_file, "w+") as ff:
-            ff.write(
-                "walker\t{}\tlog_l\tlog_p\n".format(
-                    "\t".join(self.search_parameter_keys)
-                )
-            )
-
-    @property
-    def checkpoint_info(self):
-        """Defines various things related to checkpointing and storing data
-
-        Returns
-        =======
-        checkpoint_info: named_tuple
-            An object with attributes `sampler_file`, `chain_file`, and
-            `chain_template`. The first two give paths to where the sampler and
-            chain data is stored, the last a formatted-str-template with which
-            to write the chain data to disk
-
-        """
-        out_dir = os.path.join(
-            self.outdir, "{}_{}".format(self.__class__.__name__.lower(), self.label)
-        )
-        check_directory_exists_and_if_not_mkdir(out_dir)
-
-        chain_file = os.path.join(out_dir, "chain.dat")
-        sampler_file = os.path.join(out_dir, "sampler.pickle")
-        chain_template = (
-            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
-        )
-
-        CheckpointInfo = namedtuple(
-            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
-        )
-
-        checkpoint_info = CheckpointInfo(
-            sampler_file=sampler_file,
-            chain_file=chain_file,
-            chain_template=chain_template,
-        )
-
-        return checkpoint_info
-
-    @property
-    def sampler_chain(self):
-        nsteps = self._previous_iterations
-        return self.sampler.chain[:, :nsteps, :]
-
-    def checkpoint(self):
-        """Writes a pickle file of the sampler to disk using dill"""
-        import dill
-
-        logger.info(
-            "Checkpointing sampler to file {}".format(self.checkpoint_info.sampler_file)
-        )
-        with open(self.checkpoint_info.sampler_file, "wb") as f:
-            # Overwrites the stored sampler chain with one that is truncated
-            # to only the completed steps
-            self.sampler._chain = self.sampler_chain
-            dill.dump(self._sampler, f)
-
-    def checkpoint_and_exit(self, signum, frame):
-        logger.info("Received signal {}".format(signum))
-        self.checkpoint()
-        sys.exit()
+    def write_current_state(self):
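+        # a pool cannot be pickled: swap in the builtin map while checkpointing, then restore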
+        self._sampler.distribute = map
+        super(Zeus, self).write_current_state()
+        self._sampler.distribute = getattr(self._sampler.pool, "map", map)
 
     def _initialise_sampler(self):
-        self._sampler = self.zeus.EnsembleSampler(**self.sampler_init_kwargs)
-        self._init_chain_file()
+        from zeus import EnsembleSampler
 
-    @property
-    def sampler(self):
-        """Returns the Zeus sampler object
-
-        If, already initialized, returns the stored _sampler value. Otherwise,
-        first checks if there is a pickle file from which to load. If there is
-        not, then initialize the sampler and set the initial random draw
-
-        """
-        if hasattr(self, "_sampler"):
-            pass
-        elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
-            import dill
-
-            logger.info(
-                "Resuming run from checkpoint file {}".format(
-                    self.checkpoint_info.sampler_file
-                )
-            )
-            with open(self.checkpoint_info.sampler_file, "rb") as f:
-                self._sampler = dill.load(f)
-            self._set_pos0_for_resume()
-        else:
-            self._initialise_sampler()
-            self._set_pos0()
-        return self._sampler
+        self._sampler = EnsembleSampler(**self.sampler_init_kwargs)
+        self._init_chain_file()
 
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
@@ -310,48 +138,12 @@ class Zeus(MCMCSampler):
                 ff.write(self.checkpoint_info.chain_template.format(ii, *point))
         shutil.move(temp_chain_file, chain_file)
 
-    @property
-    def _previous_iterations(self):
-        """Returns the number of iterations that the sampler has saved
-
-        This is used when loading in a sampler from a pickle file to figure out
-        how much of the run has already been completed
-        """
-        try:
-            return len(self.sampler.get_blobs())
-        except AttributeError:
-            return 0
-
-    def _draw_pos0_from_prior(self):
-        return np.array(
-            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
-        )
-
-    @property
-    def _pos0_shape(self):
-        return (self.nwalkers, self.ndim)
-
-    def _set_pos0(self):
-        if self.pos0 is not None:
-            logger.debug("Using given initial positions for walkers")
-            if isinstance(self.pos0, DataFrame):
-                self.pos0 = self.pos0[self.search_parameter_keys].values
-            elif type(self.pos0) in (list, np.ndarray):
-                self.pos0 = np.squeeze(self.pos0)
-
-            if self.pos0.shape != self._pos0_shape:
-                raise ValueError("Input pos0 should be of shape ndim, nwalkers")
-            logger.debug("Checking input pos0")
-            for draw in self.pos0:
-                self.check_draw(draw)
-        else:
-            logger.debug("Generating initial walker positions from prior")
-            self.pos0 = self._draw_pos0_from_prior()
-
     def _set_pos0_for_resume(self):
         self.pos0 = self.sampler.get_last_sample()
 
+    @signal_wrapper
     def run_sampler(self):
+        self._setup_pool()
         sampler_function_kwargs = self.sampler_function_kwargs
         iterations = sampler_function_kwargs.pop("iterations")
         iterations -= self._previous_iterations
@@ -363,7 +155,8 @@ class Zeus(MCMCSampler):
             iterations=iterations, **sampler_function_kwargs
         ):
             self.write_chains_to_file(sample)
-        self.checkpoint()
+        self._close_pool()
+        self.write_current_state()
 
         self.result.sampler_output = np.nan
         self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
@@ -381,10 +174,12 @@ class Zeus(MCMCSampler):
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps` ({} < {}). Try increasing the "
-                "number of steps.".format(self.result.nburn, self.nsteps)
+                f"`nburn < nsteps` ({self.result.nburn} < {self.nsteps})."
+                " Try increasing the number of steps."
             )
-        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2))
+        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape(
+            (-1, 2)
+        )
         log_likelihoods, log_priors = blobs.T
         self.result.log_likelihood_evaluations = log_likelihoods
         self.result.log_prior_evaluations = log_priors
diff --git a/containers/dockerfile-template b/containers/dockerfile-template
index a594a9a996c33e50a2c873e42585f6269c3eda12..f4eca79d814edd4ecb9a6ba2e2fbd31bbeafa2ad 100644
--- a/containers/dockerfile-template
+++ b/containers/dockerfile-template
@@ -33,6 +33,7 @@ RUN pip install corner healpy cython tables
 RUN conda install -n ${{conda_env}} -c conda-forge dynesty emcee nestle ptemcee
 RUN conda install -n ${{conda_env}} -c conda-forge pymultinest ultranest
 RUN conda install -n ${{conda_env}} -c conda-forge cpnest kombine dnest4 zeus-mcmc
+RUN conda install -n ${{conda_env}} -c conda-forge ptmcmcsampler
 RUN conda install -n ${{conda_env}} -c conda-forge pytorch
 RUN conda install -n ${{conda_env}} -c conda-forge theano-pymc
 RUN conda install -n ${{conda_env}} -c conda-forge pymc3
@@ -49,10 +50,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${{conda_env}} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
diff --git a/containers/v3-dockerfile-test-suite-python38 b/containers/v3-dockerfile-test-suite-python38
index 04452eac65334c2b813219d66033802a98ed7215..4ee250e8b5792e2c96c78342a974c6a4fd8b97d9 100644
--- a/containers/v3-dockerfile-test-suite-python38
+++ b/containers/v3-dockerfile-test-suite-python38
@@ -35,6 +35,7 @@ RUN pip install corner healpy cython tables
 RUN conda install -n ${conda_env} -c conda-forge dynesty emcee nestle ptemcee
 RUN conda install -n ${conda_env} -c conda-forge pymultinest ultranest
 RUN conda install -n ${conda_env} -c conda-forge cpnest kombine dnest4 zeus-mcmc
+RUN conda install -n ${conda_env} -c conda-forge ptmcmcsampler
 RUN conda install -n ${conda_env} -c conda-forge pytorch
 RUN conda install -n ${conda_env} -c conda-forge theano-pymc
 RUN conda install -n ${conda_env} -c conda-forge pymc3
@@ -51,10 +52,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
diff --git a/containers/v3-dockerfile-test-suite-python39 b/containers/v3-dockerfile-test-suite-python39
index abaa39f235771b5489e2a8c647072599e7a93e1d..218b3d3e24b3b65b278a95958ffb0444b4f85d7c 100644
--- a/containers/v3-dockerfile-test-suite-python39
+++ b/containers/v3-dockerfile-test-suite-python39
@@ -35,6 +35,7 @@ RUN pip install corner healpy cython tables
 RUN conda install -n ${conda_env} -c conda-forge dynesty emcee nestle ptemcee
 RUN conda install -n ${conda_env} -c conda-forge pymultinest ultranest
 RUN conda install -n ${conda_env} -c conda-forge cpnest kombine dnest4 zeus-mcmc
+RUN conda install -n ${conda_env} -c conda-forge ptmcmcsampler
 RUN conda install -n ${conda_env} -c conda-forge pytorch
 RUN conda install -n ${conda_env} -c conda-forge theano-pymc
 RUN conda install -n ${conda_env} -c conda-forge pymc3
@@ -51,10 +52,6 @@ RUN apt-get install -y gfortran
 RUN git clone https://github.com/PolyChord/PolyChordLite.git \
 && (cd PolyChordLite && python setup.py --no-mpi install)
 
-# Install PTMCMCSampler
-RUN git clone https://github.com/jellis18/PTMCMCSampler.git \
-&& (cd PTMCMCSampler && python setup.py install)
-
 # Install GW packages
 RUN conda install -n ${conda_env} -c conda-forge python-lalsimulation bilby.cython
 RUN pip install ligo-gracedb gwpy ligo.skymap
diff --git a/test/bilby_mcmc/test_sampler.py b/test/bilby_mcmc/test_sampler.py
index aa52967da16cbfbf754279deed3c83139632d1e7..746eb1a9e1150732e93d5c31664751040e7b639c 100644
--- a/test/bilby_mcmc/test_sampler.py
+++ b/test/bilby_mcmc/test_sampler.py
@@ -3,7 +3,7 @@ import shutil
 import unittest
 
 import bilby
-from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler, _initialize_global_variables
+from bilby.bilby_mcmc.sampler import Bilby_MCMC, BilbyMCMCSampler
 from bilby.bilby_mcmc.utils import ConvergenceInputs
 from bilby.core.sampler.base_sampler import SamplerError
 import numpy as np
@@ -44,7 +44,12 @@ class TestBilbyMCMCSampler(unittest.TestCase):
         search_parameter_keys = ['m', 'c']
         use_ratio = False
 
-        _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio)
+        bilby.core.sampler.base_sampler._initialize_global_variables(
+            likelihood,
+            priors,
+            search_parameter_keys,
+            use_ratio,
+        )
 
     def tearDown(self):
         if os.path.isdir(self.outdir):
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index 30be5e2ba542205d4cdbef10c6f8e9d681bcbbf1..3a1059e0dd82ac27b1897713b6b40f9f702bf644 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -1,7 +1,9 @@
 import copy
 import os
+import shutil
 import unittest
 from unittest.mock import MagicMock
+from parameterized import parameterized
 
 import numpy as np
 
@@ -102,22 +104,100 @@ class TestSampler(unittest.TestCase):
         self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None)
 
     def test_bad_value_np_abs_nan(self):
-        self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.abs(np.nan), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_abs_nan(self):
-        self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=abs(np.nan), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_pos_inf(self):
         self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None)
 
     def test_bad_value_neg_inf(self):
-        self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=-np.inf, warning=False, theta=None, label=None
+        )
 
     def test_bad_value_pos_inf_nan_to_num(self):
-        self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.nan_to_num(np.inf), warning=False, theta=None, label=None
+        )
 
     def test_bad_value_neg_inf_nan_to_num(self):
-        self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None)
+        self.sampler._check_bad_value(
+            val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None
+        )
+
+
+samplers = [
+    "bilby_mcmc",
+    "dynamic_dynesty",
+    "dynesty",
+    "emcee",
+    "kombine",
+    "ptemcee",
+    "zeus",
+]
+
+
+class GenericSamplerTest(unittest.TestCase):
+    def setUp(self):
+        self.likelihood = bilby.core.likelihood.Likelihood(dict())
+        self.priors = bilby.core.prior.PriorDict(
+            dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1))
+        )
+
+    def tearDown(self):
+        if os.path.isdir("outdir"):
+            shutil.rmtree("outdir")
+
+    @parameterized.expand(samplers)
+    def test_pool_creates_properly_no_pool(self, sampler_name):
+        sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name](
+            self.likelihood, self.priors
+        )
+        sampler._setup_pool()
+        if sampler_name == "kombine":
+            from kombine import SerialPool
+
+            self.assertIsInstance(sampler.pool, SerialPool)
+        else:
+            self.assertIsNone(sampler.pool)
+
+    @parameterized.expand(samplers)
+    def test_pool_creates_properly_pool(self, sampler_name):
+        sampler = bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler_name](
+            self.likelihood, self.priors, npool=2
+        )
+        sampler._setup_pool()
+        if hasattr(sampler, "setup_sampler"):
+            sampler.setup_sampler()
+        self.assertEqual(sampler.pool._processes, 2)
+        sampler._close_pool()
+
+
+class ReorderLikelihoodsTest(unittest.TestCase):
+    def setUp(self):
+        self.unsorted_ln_likelihoods = np.array([1, 5, 2, 5, 1])
+        self.unsorted_samples = np.array([[0, 1], [1, 1], [1, 0], [0, 0], [0, 1]])
+        self.sorted_samples = np.array([[0, 1], [0, 1], [1, 0], [1, 1], [0, 0]])
+        self.sorted_ln_likelihoods = np.array([1, 1, 2, 5, 5])
+
+    def tearDown(self):
+        pass
+
+    def test_ordering(self):
+        func = bilby.core.sampler.base_sampler.NestedSampler.reorder_loglikelihoods
+        sorted_ln_likelihoods = func(
+            self.unsorted_ln_likelihoods, self.unsorted_samples, self.sorted_samples
+        )
+        self.assertTrue(
+            np.array_equal(sorted_ln_likelihoods, self.sorted_ln_likelihoods)
+        )
 
 
 if __name__ == "__main__":
diff --git a/test/core/sampler/ultranest_test.py b/test/core/sampler/ultranest_test.py
index dc578cd71932c877f0de8414361781cc86837789..be22c1a1f50b8d304000fcb8d0e4816e57c9c1b9 100644
--- a/test/core/sampler/ultranest_test.py
+++ b/test/core/sampler/ultranest_test.py
@@ -28,7 +28,7 @@ class TestUltranest(unittest.TestCase):
 
     def test_default_kwargs(self):
         expected = dict(
-            resume=True,
+            resume="overwrite",
             show_status=True,
             num_live_points=None,
             wrapped_params=None,
@@ -63,7 +63,7 @@ class TestUltranest(unittest.TestCase):
 
     def test_translate_kwargs(self):
         expected = dict(
-            resume=True,
+            resume="overwrite",
             show_status=True,
             num_live_points=123,
             wrapped_params=None,
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index 2bf1d355e03cf48a91a3f9ff29b97ec47f10264e..3aa2157c04a6ad5d4008fde4a60710225e59bcb3 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -1,34 +1,100 @@
+import multiprocessing
+import os
+import sys
+import threading
+import time
+from signal import SIGINT
+
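+# must be set before any pool is created; "fork" lets workers inherit module-level state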
+multiprocessing.set_start_method("fork")  # noqa
+
 import unittest
 import pytest
+from parameterized import parameterized
 import shutil
 
 import bilby
 import numpy as np
 
 
+_sampler_kwargs = dict(
+    bilby_mcmc=dict(nsamples=200, printdt=1),
+    cpnest=dict(nlive=100),
+    dnest4=dict(
+        max_num_levels=2,
+        num_steps=10,
+        new_level_interval=10,
+        num_per_step=10,
+        thread_steps=1,
+        num_particles=50,
+        max_pool=1,
+    ),
+    dynesty=dict(nlive=100),
+    dynamic_dynesty=dict(
+        nlive_init=100,
+        nlive_batch=100,
+        dlogz_init=1.0,
+        maxbatch=0,
+        maxcall=100,
+        bound="single",
+    ),
+    emcee=dict(iterations=1000, nwalkers=10),
+    kombine=dict(iterations=200, nwalkers=10, autoburnin=False),
+    nessai=dict(
+        nlive=100,
+        poolsize=1000,
+        max_iteration=1000,
+        max_threads=3,
+    ),
+    nestle=dict(nlive=100),
+    ptemcee=dict(
+        nsamples=100,
+        nwalkers=50,
+        burn_in_act=1,
+        ntemps=1,
+        frac_threshold=0.5,
+    ),
+    PTMCMCSampler=dict(Niter=101, burn=2, isave=100),
+    # pymc3=dict(draws=50, tune=50, n_init=250),  removed until testing issue can be resolved
+    pymultinest=dict(nlive=100),
+    pypolychord=dict(nlive=100),
+    ultranest=dict(nlive=100, temporary_directory=False),
+)
+
+sampler_imports = dict(
+    bilby_mcmc="bilby",
+    dynamic_dynesty="dynesty",
+)
+
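+# samplers that cannot be run with a multiprocessing pool in these tests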
+no_pool_test = [
+    "dnest4",
+    "pymultinest",
+    "nestle",
+    "ptmcmcsampler",
+    "pypolychord",
+    "ultranest",
+]
+
+
+def slow_func(x, m, c):
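+    # slow model for the interrupt tests, so the run is still going when SIGINT arrives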
+    time.sleep(0.01)
+    return m * x + c
+
+
+def model(x, m, c):
+    return m * x + c
+
+
 class TestRunningSamplers(unittest.TestCase):
     def setUp(self):
         np.random.seed(42)
         bilby.core.utils.command_line_args.bilby_test_mode = False
         self.x = np.linspace(0, 1, 11)
-        self.model = lambda x, m, c: m * x + c
         self.injection_parameters = dict(m=0.5, c=0.2)
         self.sigma = 0.1
-        self.y = self.model(self.x, **self.injection_parameters) + np.random.normal(
+        self.y = model(self.x, **self.injection_parameters) + np.random.normal(
             0, self.sigma, len(self.x)
         )
         self.likelihood = bilby.likelihood.GaussianLikelihood(
-            self.x, self.y, self.model, self.sigma
+            self.x, self.y, model, self.sigma
         )
 
         self.priors = bilby.core.prior.PriorDict()
         self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
         self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
-        self.kwargs = dict(
-            save=False,
-            conversion_function=self.conversion_function,
-            verbose=True,
-        )
+        self._remove_tree()
         bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
 
     @staticmethod
@@ -42,226 +108,83 @@ class TestRunningSamplers(unittest.TestCase):
         del self.likelihood
         del self.priors
         bilby.core.utils.command_line_args.bilby_test_mode = False
-        shutil.rmtree("outdir")
-
-    def test_run_cpnest(self):
-        pytest.importorskip("cpnest")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="cpnest",
-            nlive=100,
-            resume=False,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dnest4(self):
-        pytest.importorskip("dnest4")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dnest4",
-            max_num_levels=2,
-            num_steps=10,
-            new_level_interval=10,
-            num_per_step=10,
-            thread_steps=1,
-            num_particles=50,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dynesty(self):
-        pytest.importorskip("dynesty")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dynesty",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_dynamic_dynesty(self):
-        pytest.importorskip("dynesty")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="dynamic_dynesty",
-            nlive_init=100,
-            nlive_batch=100,
-            dlogz_init=1.0,
-            maxbatch=0,
-            maxcall=100,
-            bound="single",
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_emcee(self):
-        pytest.importorskip("emcee")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="emcee",
-            iterations=1000,
-            nwalkers=10,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_kombine(self):
-        pytest.importorskip("kombine")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="kombine",
-            iterations=2000,
-            nwalkers=20,
-            autoburnin=False,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_nestle(self):
-        pytest.importorskip("nestle")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="nestle",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_nessai(self):
-        pytest.importorskip("nessai")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="nessai",
-            nlive=100,
-            poolsize=1000,
-            max_iteration=1000,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_pypolychord(self):
-        pytest.importorskip("pypolychord")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pypolychord",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_ptemcee(self):
-        pytest.importorskip("ptemcee")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ptemcee",
-            nsamples=100,
-            nwalkers=50,
-            burn_in_act=1,
-            ntemps=1,
-            frac_threshold=0.5,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    @pytest.mark.xfail(
-        raises=AttributeError,
-        reason="Dependency issue with pymc3 causes attribute error on import",
-    )
-    def test_run_pymc3(self):
-        pytest.importorskip("pymc3")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pymc3",
-            draws=50,
-            tune=50,
-            n_init=250,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_pymultinest(self):
-        pytest.importorskip("pymultinest")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="pymultinest",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_PTMCMCSampler(self):
-        pytest.importorskip("PTMCMCSampler")
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="PTMCMCsampler",
-            Niter=101,
-            burn=2,
-            isave=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_ultranest(self):
-        pytest.importorskip("ultranest")
-        # run using NestedSampler (with nlive specified)
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ultranest",
-            nlive=100,
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-        # run using ReactiveNestedSampler (with no nlive given)
-        res = bilby.run_sampler(
-            likelihood=self.likelihood,
-            priors=self.priors,
-            sampler="ultranest",
-            **self.kwargs,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
-
-    def test_run_bilby_mcmc(self):
+        self._remove_tree()
+
+    def _remove_tree(self):
+        try:
+            shutil.rmtree("outdir")
+        except OSError:
+            pass
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_run_sampler_single(self, sampler):
+        self._run_sampler(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_run_sampler_pool(self, sampler):
+        self._run_sampler(sampler, pool_size=2)
+
+    def _run_sampler(self, sampler, pool_size, **extra_kwargs):
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
+        kwargs = _sampler_kwargs[sampler]
         res = bilby.run_sampler(
             likelihood=self.likelihood,
             priors=self.priors,
-            sampler="bilby_mcmc",
-            nsamples=200,
-            **self.kwargs,
-            printdt=1,
-        )
-        assert "derived" in res.posterior
-        assert res.log_likelihood_evaluations is not None
+            sampler=sampler,
+            save=False,
+            npool=pool_size,
+            conversion_function=self.conversion_function,
+            **kwargs,
+            **extra_kwargs,
+        )
+        assert "derived" in res.posterior
+        assert res.log_likelihood_evaluations is not None
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_single(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=1)
+
+    @parameterized.expand(_sampler_kwargs.keys())
+    def test_interrupt_sampler_pool(self, sampler):
+        self._run_with_signal_handling(sampler, pool_size=2)
+
+    def _run_with_signal_handling(self, sampler, pool_size=1):
+        pytest.importorskip(sampler_imports.get(sampler, sampler))
+        if bilby.core.sampler.IMPLEMENTED_SAMPLERS[sampler.lower()].hard_exit:
+            pytest.skip(f"{sampler} hard exits, can't test signal handling.")
+        if pool_size > 1 and sampler.lower() in no_pool_test:
+            pytest.skip(f"{sampler} cannot be parallelized")
+        if sys.version_info.minor == 8 and sampler.lower() == "cpnest":
+            pytest.skip("Pool interrupting broken for cpnest with py3.8")
+        if sampler.lower() == "nessai" and pool_size > 1:
+            pytest.skip(
+                "Interrupting with a pool is failing in pytest. "
+                "Likely due to interactions with the signal handling in nessai."
+            )
+        pid = os.getpid()
+        print(sampler)
+
+        def trigger_signal():
+            # wait for sampling to start before sending the interrupt
+            time.sleep(4)
+            os.kill(pid, SIGINT)
+
+        thread = threading.Thread(target=trigger_signal)
+        thread.daemon = True
+        thread.start()
+
+        self.likelihood._func = slow_func
+
+        with self.assertRaises((SystemExit, KeyboardInterrupt)):
+            try:
+                while True:
+                    self._run_sampler(sampler=sampler, pool_size=pool_size, exit_code=5)
+            except SystemExit as error:
+                self.assertEqual(error.code, 5)
+                raise
 
 
 if __name__ == "__main__":