diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 6c1e169725221e4ad27facb7e7a7de517da4c6c9..07d070d0e71fe9b9fa01a93c1453f5dcb9264c80 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -1,10 +1,14 @@
 from __future__ import absolute_import, print_function
 
+from collections import namedtuple
 import os
+import signal
+import sys
 
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
+import dill as pickle
 
 from ..utils import (
     logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
@@ -66,6 +70,9 @@ class Emcee(MCMCSampler):
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
 
+        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
+        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
+
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
             for equiv in self.nwalkers_equiv_kwargs:
@@ -165,66 +172,139 @@ class Emcee(MCMCSampler):
     def nsteps(self, nsteps):
         self.kwargs['iterations'] = nsteps
 
-    def __getstate__(self):
-        # In order to be picklable with dill, we need to discard the pool
-        # object before trying.
-        d = self.__dict__
-        d["_Sampler__kwargs"]["pool"] = None
-        return d
+    @property
+    def stored_chain(self):
+        """ Read the stored zero-temperature chain data in from disk """
+        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
 
-    def set_up_checkpoint(self):
-        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
-        out_file = os.path.join(out_dir, 'chain.dat')
+    @property
+    def stored_samples(self):
+        """ Returns the samples stored on disk """
+        return self.stored_chain[self.search_parameter_keys]
 
-        if self.resume:
-            self.load_old_chain(out_file)
-        else:
-            self._set_pos0()
+    @property
+    def stored_loglike(self):
+        """ Returns the log-likelihood stored on disk """
+        return self.stored_chain['log_l']
 
+    @property
+    def stored_logprior(self):
+        """ Returns the log-prior stored on disk """
+        return self.stored_chain['log_p']
+
+    @property
+    def checkpoint_info(self):
+        """ Defines various things related to checkpointing and storing data
+
+        Returns
+        -------
+        checkpoint_info: named_tuple
+            An object with attributes `sampler_file`, `chain_file`, and
+            `chain_template`. The first two give paths to where the sampler and
+            chain data is stored, the last a formatted-str-template with which
+            to write the chain data to disk
+
+        """
+        out_dir = os.path.join(
+            self.outdir, '{}_{}'.format(self.__class__.__name__, self.label))
         check_directory_exists_and_if_not_mkdir(out_dir)
-        if not os.path.isfile(out_file):
-            with open(out_file, "w") as ff:
-                ff.write('walker\t{}\tlog_l\n'.format(
+
+        sampler_file = os.path.join(out_dir, 'sampler.pickle')
+
+        # Initialise chain file
+        chain_file = os.path.join(out_dir, 'chain.dat')
+        if not os.path.isfile(chain_file):
+            with open(chain_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
                     '\t'.join(self.search_parameter_keys)))
-        template =\
+        chain_template =\
             '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
 
-        return out_file, template
+        CheckpointInfo = namedtuple(
+            'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
 
-    def run_sampler(self):
+        checkpoint_info = CheckpointInfo(
+            sampler_file=sampler_file, chain_file=chain_file,
+            chain_template=chain_template)
+
+        return checkpoint_info
+
+    @property
+    def sampler_chain(self):
+        nsteps = self._previous_iterations
+        return self.sampler.chain[:, :nsteps, :]
+
+    def checkpoint(self):
+        """ Writes a pickle file of the sampler to disk using dill """
+        logger.info("Checkpointing sampler to file {}"
+                    .format(self.checkpoint_info.sampler_file))
+        with open(self.checkpoint_info.sampler_file, 'wb') as f:
+            # Overwrites the stored sampler chain with one that is truncated
+            # to only the completed steps
+            self.sampler._chain = self.sampler_chain
+            pickle.dump(self._sampler, f)
+
+    def checkpoint_and_exit(self, signum, frame):
+        logger.info("Recieved signal {}".format(signum))
+        self.checkpoint()
+        sys.exit()
+
+    def _initialise_sampler(self):
         import emcee
-        tqdm = get_progress_bar()
-        sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        out_file, template = self.set_up_checkpoint()
+        self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
+
+    def _set_pos0_for_resume(self):
+        self.pos0 = self.sampler.chain[:, -1, :]
+
+    @property
+    def sampler(self):
+        """ Returns the ptemcee sampler object
+
+        If, alrady initialized, returns the stored _sampler value. Otherwise,
+        first checks if there is a pickle file from which to load. If there is
+        not, then initialize the sampler and set the initial random draw
+
+        """
+        if hasattr(self, '_sampler'):
+            pass
+        elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
+            with open(self.checkpoint_info.sampler_file, 'rb') as f:
+                self._sampler = pickle.load(f)
+            self._set_pos0_for_resume()
+        else:
+            self._initialise_sampler()
+            self._set_pos0()
+        return self._sampler
 
+    def write_chains_to_file(self, sample):
+        if self.prerelease:
+            points = np.hstack([sample.coords, sample.blobs])
+        else:
+            points = np.hstack([sample[0], np.array(sample[3])])
+        with open(self.checkpoint_info.chain_file, "a") as ff:
+            for ii, point in enumerate(points):
+                ff.write(self.checkpoint_info.chain_template.format(ii, *point))
+
+    def run_sampler(self):
+        tqdm = get_progress_bar()
         sampler_function_kwargs = self.sampler_function_kwargs
         iterations = sampler_function_kwargs.pop('iterations')
         iterations -= self._previous_iterations
 
+        print('pos0', self.pos0)
+        sampler_function_kwargs['p0'] = self.pos0
+
         for sample in tqdm(
-                sampler.sample(iterations=iterations, **sampler_function_kwargs),
+                self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
                 total=iterations):
-            if self.prerelease:
-                points = np.hstack([sample.coords, sample.blobs])
-            else:
-                points = np.hstack([sample[0], np.array(sample[3])])
-            with open(out_file, "a") as ff:
-                for ii, point in enumerate(points):
-                    ff.write(template.format(ii, *point))
+            self.write_chains_to_file(sample)
 
         self.result.sampler_output = np.nan
-        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
         log_likelihoods, log_priors = blobs_flat.T
-        if self._old_chain is not None:
-            chain = np.vstack([self._old_chain[:, :-2],
-                               sampler.chain.reshape((-1, self.ndim))])
-            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
-            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
-            self.nsteps = chain.shape[0] // self.nwalkers
-        else:
-            chain = sampler.chain.reshape((-1, self.ndim))
-            log_ls = log_likelihoods
-            log_ps = log_priors
+        chain = self.sampler.chain.reshape((-1, self.ndim))
+        log_ls = log_likelihoods
+        log_ps = log_priors
         self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
@@ -236,13 +316,27 @@ class Emcee(MCMCSampler):
         self.result.samples = chain[n_samples:, :]
         self.result.log_likelihood_evaluations = log_ls[n_samples:]
         self.result.log_prior_evaluations = log_ps[n_samples:]
-        self.result.walkers = sampler.chain
+        self.result.walkers = self.sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
         return self.result
 
+    @property
+    def _previous_iterations(self):
+        """ Returns the number of iterations that the sampler has saved
+
+        This is used when loading in a sampler from a pickle file to figure out
+        how much of the run has already been completed
+        """
+        return len(self.sampler.blobs)
+
     def _draw_pos0_from_prior(self):
-        return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+        return np.array(
+            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
+
+    @property
+    def _pos0_shape(self):
+        return (self.nwalkers, self.ndim)
 
     def _set_pos0(self):
         if self.pos0 is not None:
@@ -250,9 +344,9 @@ class Emcee(MCMCSampler):
             if isinstance(self.pos0, DataFrame):
                 self.pos0 = self.pos0[self.search_parameter_keys].values
             elif type(self.pos0) in (list, np.ndarray):
-                self.pos0 = np.squeeze(self.kwargs['pos0'])
+                self.pos0 = np.squeeze(self.pos0)
 
-            if self.pos0.shape != (self.nwalkers, self.ndim):
+            if self.pos0.shape != self._pos0_shape:
                 raise ValueError(
                     'Input pos0 should be of shape ndim, nwalkers')
             logger.debug("Checking input pos0")
@@ -262,51 +356,6 @@ class Emcee(MCMCSampler):
             logger.debug("Generating initial walker positions from prior")
             self.pos0 = self._draw_pos0_from_prior()
 
-    @property
-    def _old_chain(self):
-        try:
-            old_chain = self.__old_chain
-            n = old_chain.shape[0]
-            idx = n - np.mod(n, self.nwalkers)
-            return old_chain[:idx, :]
-        except AttributeError:
-            return None
-
-    @_old_chain.setter
-    def _old_chain(self, old_chain):
-        self.__old_chain = old_chain
-
-    @property
-    def _previous_iterations(self):
-        if self._old_chain is None:
-            return 0
-        try:
-            return self._old_chain.shape[0] // self.nwalkers
-        except AttributeError:
-            logger.warning(
-                "Unable to calculate previous iterations from checkpoint,"
-                " defaulting to zero")
-            return 0
-
-    def load_old_chain(self, file_name=None):
-        if file_name is None:
-            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
-            file_name = os.path.join(out_dir, 'chain.dat')
-        if os.path.isfile(file_name):
-            try:
-                old_chain = np.genfromtxt(file_name, skip_header=1)
-                self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
-                             for ii in range(self.nwalkers)]
-                self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
-                logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
-            except Exception:
-                logger.warning('Failed to resume. Corrupt checkpoint file {}.'
-                               .format(file_name))
-                self._set_pos0()
-        else:
-            logger.warning('Failed to resume. {} not found.'.format(file_name))
-            self._set_pos0()
-
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index 7f02f45e1c5c0aa7e576cd49644c0c66554aaa01..3325d3a55ede3c7e0f566706da9d304b0c5c2ef3 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -1,12 +1,8 @@
 from __future__ import absolute_import, division, print_function
 
-import os
-from collections import namedtuple
-
 import numpy as np
 
-from ..utils import (
-    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
+from ..utils import logger, get_progress_bar
 from . import Emcee
 from .base_sampler import SamplerError
 
@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
         The number of temperatures used by ptemcee
 
     """
-    default_kwargs = dict(ntemps=2, nwalkers=500, Tmax=None, betas=None,
-                          threads=1, pool=None, a=2.0, loglargs=[], logpargs=[],
-                          loglkwargs={}, logpkwargs={}, adaptation_lag=10000,
-                          adaptation_time=100, random=None, iterations=100,
-                          thin=1, storechain=True, adapt=True, swap_ratios=False,
-                          )
+    default_kwargs = dict(
+        ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
+        a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
+        adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
+        thin=1, storechain=True, adapt=True, swap_ratios=False)
 
     def __init__(self, likelihood, priors, outdir='outdir', label='label',
                  use_ratio=False, plot=False, skip_import_verification=False,
@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
                 if key not in self.sampler_function_kwargs}
 
     @property
-    def checkpoint_info(self):
-        out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
-        chain_file = os.path.join(out_dir, 'chain.dat')
-        last_pos_file = os.path.join(out_dir, 'last_pos.npy')
-
-        check_directory_exists_and_if_not_mkdir(out_dir)
-        if not os.path.isfile(chain_file):
-            with open(chain_file, "w") as ff:
-                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
-                    '\t'.join(self.search_parameter_keys)))
-        template =\
-            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
-
-        CheckpointInfo = namedtuple(
-            'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
-
-        checkpoint_info = CheckpointInfo(
-            last_pos_file=last_pos_file, chain_file=chain_file, template=template)
-
-        return checkpoint_info
+    def ntemps(self):
+        return self.kwargs['ntemps']
 
     def _draw_pos0_from_prior(self):
+        # for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
         return [[self.get_random_draw_from_prior()
                  for _ in range(self.nwalkers)]
                 for _ in range(self.kwargs['ntemps'])]
 
-    @property
-    def _old_chain(self):
-        try:
-            old_chain = self.__old_chain
-            n = old_chain.shape[0]
-            idx = n - np.mod(n, self.nwalkers)
-            return old_chain[:idx]
-        except AttributeError:
-            return None
-
-    @_old_chain.setter
-    def _old_chain(self, old_chain):
-        self.__old_chain = old_chain
+    def _set_pos0_for_resume(self):
+        self.pos0 = None
 
     @property
-    def stored_chain(self):
-        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
+    def _previous_iterations(self):
+        """ Returns the number of iterations that the sampler has saved
 
-    @property
-    def stored_samples(self):
-        return self.stored_chain[self.search_parameter_keys]
+        This is used when loading in a sampler from a pickle file to figure out
+        how much of the run has already been completed
+        """
+        return self.sampler.time
 
     @property
-    def stored_loglike(self):
-        return self.stored_chain['log_l']
+    def sampler_chain(self):
+        nsteps = self._previous_iterations
+        return self.sampler.chain[:, :, :nsteps, :]
 
     @property
-    def stored_logprior(self):
-        return self.stored_chain['log_p']
-
-    def load_old_chain(self):
-        try:
-            last_pos = np.load(self.checkpoint_info.last_pos_file)
-            self.pos0 = last_pos
-            self._old_chain = self.stored_samples
-            logger.info(
-                'Resuming from {} with {} iterations'.format(
-                    self.checkpoint_info.chain_file,
-                    self._previous_iterations))
-        except Exception:
-            logger.info('Unable to resume')
-            self._set_pos0()
+    def _pos0_shape(self):
+        return (self.ntemps, self.nwalkers, self.ndim)
 
-    def run_sampler(self):
+    def _initialise_sampler(self):
         import ptemcee
-        tqdm = get_progress_bar()
-        sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
-                                  logp=self.log_prior, **self.sampler_init_kwargs)
-
-        if self.resume:
-            self.load_old_chain()
-        else:
-            self._set_pos0()
+        self._sampler = ptemcee.Sampler(
+            dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
+            **self.sampler_init_kwargs)
+
+    def print_tswap_acceptance_fraction(self):
+        logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
+            self.sampler.tswap_acceptance_fraction))
+
+    def write_chains_to_file(self, pos, loglike, logpost):
+        with open(self.checkpoint_info.chain_file, "a") as ff:
+            loglike = np.squeeze(loglike[0, :])
+            logprior = np.squeeze(logpost[0, :]) - loglike
+            for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
+                line = np.concatenate((point, [logl, logp]))
+                ff.write(self.checkpoint_info.chain_template.format(ii, *line))
 
+    def run_sampler(self):
+        tqdm = get_progress_bar()
         sampler_function_kwargs = self.sampler_function_kwargs
         iterations = sampler_function_kwargs.pop('iterations')
         iterations -= self._previous_iterations
 
+        # main iteration loop
         for pos, logpost, loglike in tqdm(
-                sampler.sample(self.pos0, iterations=iterations,
-                               **sampler_function_kwargs),
+                self.sampler.sample(self.pos0, iterations=iterations,
+                                    **sampler_function_kwargs),
                 total=iterations):
-            np.save(self.checkpoint_info.last_pos_file, pos)
-            with open(self.checkpoint_info.chain_file, "a") as ff:
-                loglike = np.squeeze(loglike[:1, :])
-                logprior = np.squeeze(logpost[:1, :]) - loglike
-                for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
-                    line = np.concatenate((point, [logl, logp]))
-                    ff.write(self.checkpoint_info.template.format(ii, *line))
-
-        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
+            self.write_chains_to_file(pos, loglike, logpost)
+
+        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
         self.result.sampler_output = np.nan
         self.print_nburn_logging_info()
+        self.print_tswap_acceptance_fraction()
+
         self.result.nburn = self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn < nsteps`. Try increasing the number of steps.")
-        walkers = self.stored_samples.view((float, self.ndim))
-        walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim)
-        self.result.walkers = walkers
-        self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim))
+
+        self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
+            (-1, self.ndim))
+        self.result.walkers = self.sampler.chain[0, :, :, :]
+
         n_samples = self.nwalkers * self.nburn
         self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
         self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
-        self.result.betas = sampler.betas
+        self.result.betas = self.sampler.betas
         self.result.log_evidence, self.result.log_evidence_err =\
-            sampler.log_evidence_estimate(
-                sampler.loglikelihood, self.nburn / self.nsteps)
+            self.sampler.log_evidence_estimate(
+                self.sampler.loglikelihood, self.nburn / self.nsteps)
 
         return self.result
diff --git a/requirements.txt b/requirements.txt
index bb184b6a62790c07f5b232056a2d89b0291f398e..de58a6b16f36ea4614961dcdb0b54cce46167f8e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,3 +6,4 @@ matplotlib>=2.0
 scipy>=0.16
 pandas
 mock
+dill
diff --git a/setup.py b/setup.py
index 81575551bd6cfc9ae8a6682328739361c9141d8d..b535cc65eed11c33d1d93263b243cab2071c71af 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,7 @@ setup(name='bilby',
           'future',
           'dynesty',
           'corner',
+          'dill',
           'numpy>=1.9',
           'matplotlib>=2.0',
           'pandas',