diff --git a/.dictionary.txt b/.dictionary.txt
new file mode 100644
index 0000000000000000000000000000000000000000..6b8adc596cdc293dfd4d158aebeb908adf705139
--- /dev/null
+++ b/.dictionary.txt
@@ -0,0 +1,2 @@
+hist
+livetime
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 584f8a6395f249eb1ea039ff8fff466e968b99ef..0d7ff4d82f43c8bdb696ab211490d216a91aa682 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -4,19 +4,27 @@ repos:
     hooks:
     -   id: check-merge-conflict # prevent committing files with merge conflicts
     -   id: flake8 # checks for flake8 errors
-#-   repo: https://github.com/codespell-project/codespell
-#    rev: v1.16.0
-#    hooks:
-#    -   id: codespell # Spellchecker
-#        args: [-L, nd, --skip, "*.ipynb,*.html", --ignore-words=.dictionary.txt]
-#        exclude: ^examples/tutorials/
-#-   repo: https://github.com/asottile/seed-isort-config
-#    rev: v1.3.0
-#    hooks:
-#    -   id: seed-isort-config
-#        args: [--application-directories, 'bilby/']
-#-   repo: https://github.com/pre-commit/mirrors-isort
-#    rev: v4.3.21
-#    hooks:
-#    -   id: isort # sort imports alphabetically and separates import into sections
-#        args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
+-   repo: https://github.com/psf/black
+    rev: 20.8b1
+    hooks:
+      - id: black
+        language_version: python3
+        files: ^bilby/bilby_mcmc/
+-   repo: https://github.com/codespell-project/codespell
+    rev: v1.16.0
+    hooks:
+    -   id: codespell
+        args: [--ignore-words=.dictionary.txt]
+        files: ^bilby/bilby_mcmc/
+-   repo: https://github.com/asottile/seed-isort-config
+    rev: v1.3.0
+    hooks:
+    -   id: seed-isort-config
+        args: [--application-directories, 'bilby/']
+        files: ^bilby/bilby_mcmc/
+-   repo: https://github.com/pre-commit/mirrors-isort
+    rev: v4.3.21
+    hooks:
+    -   id: isort # sorts imports alphabetically and separates them into sections
+        args: [-w=88, -m=3, -tc, -sp=setup.cfg ]
+        files: ^bilby/bilby_mcmc/
diff --git a/bilby/bilby_mcmc/__init__.py b/bilby/bilby_mcmc/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbf49adb19d17adfc5553b0e5c06e46e7123c554
--- /dev/null
+++ b/bilby/bilby_mcmc/__init__.py
@@ -0,0 +1 @@
+from .sampler import Bilby_MCMC
diff --git a/bilby/bilby_mcmc/chain.py b/bilby/bilby_mcmc/chain.py
new file mode 100644
index 0000000000000000000000000000000000000000..43c47ef2ea0616ca0ecbe54525e0c5e36134a4cd
--- /dev/null
+++ b/bilby/bilby_mcmc/chain.py
@@ -0,0 +1,526 @@
+from distutils.version import LooseVersion
+
+import numpy as np
+import pandas as pd
+
+from ..core.sampler.base_sampler import SamplerError
+from ..core.utils import logger
+from .utils import LOGLKEY, LOGLLATEXKEY, LOGPKEY, LOGPLATEXKEY
+
+
+class Chain(object):
+    def __init__(
+        self,
+        initial_sample,
+        burn_in_nact=1,
+        thin_by_nact=1,
+        fixed_discard=0,
+        autocorr_c=5,
+        min_tau=1,
+        fixed_tau=None,
+        tau_window=None,
+        block_length=100000,
+    ):
+        """Object to store a single mcmc chain
+
+        Parameters
+        ----------
+        initial_sample: bilby.bilby_mcmc.chain.Sample
+            The starting point of the chain
+        burn_in_nact, thin_by_nact : int (1, 1)
+            The number of autocorrelation times (tau) to discard for burn-in
+            and the multiplicative factor to thin by (thin_by_nact < 1). E.g.,
+            burn_in_nact=10 and thin_by_nact=1 will discard 10*tau samples from
+            the start of the chain, then thin the final chain by a factor
+            of 1*tau (resulting in independent samples).
+        fixed_discard: int (0)
+            A fixed minimum number of samples to discard (can be used to
+            override the burn_in_nact if it is too small).
+        autocorr_c: float (5)
+            The step size of the window search used by emcee.autocorr when
+            estimating the autocorrelation time.
+        min_tau: int (1)
+            A minimum value for the autocorrelation time.
+        fixed_tau: int (None)
+            A fixed value for the autocorrelation (overrides the automated
+            autocorrelation time estimation). Used in testing.
+        tau_window: int (None)
+            Only calculate the autocorrelation time in a trailing window. If
+            None (default) this method is not used.
+        block_length: int
+            The incremental size to extend the array by when it runs out of
+            space.
+        """
+        self.autocorr_c = autocorr_c
+        self.min_tau = min_tau
+        self.burn_in_nact = burn_in_nact
+        self.thin_by_nact = thin_by_nact
+        self.block_length = block_length
+        self.fixed_discard = int(fixed_discard)
+        self.fixed_tau = fixed_tau
+        self.tau_window = tau_window
+
+        self.ndim = initial_sample.ndim
+        self.current_sample = initial_sample
+        self.keys = self.current_sample.keys
+        self.parameter_keys = self.current_sample.parameter_keys
+
+        # Initialize chain
+        self._chain_array = self._get_zero_chain_array()
+        self._chain_array_length = block_length
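+        # position indexes the most recently stored sample (-1 until the first append)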
+        self.position = -1
+        self.max_log_likelihood = -np.inf
+        self.max_tau_dict = {}
+        self.converged = False
+        self.cached_tau_count = 0
+        self._minimum_index_proposal = 0
+        self._minimum_index_adapt = 0
+        self._last_minimum_index = (0, 0, "I")
+        self.last_full_tau_dict = {key: np.inf for key in self.parameter_keys}
+
+        # Append the initial sample
+        self.append(self.current_sample)
+
+    def _get_zero_chain_array(self):
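+        # Each row holds the ndim parameters plus the log-likelihood and log-prior columns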
+        return np.zeros((self.block_length, self.ndim + 2), dtype=np.float64)
+
+    def _extend_chain_array(self):
+        self._chain_array = np.concatenate(
+            (self._chain_array, self._get_zero_chain_array()), axis=0
+        )
+        self._chain_array_length = len(self._chain_array)
+
+    @property
+    def current_sample(self):
+        return self._current_sample.copy()
+
+    @current_sample.setter
+    def current_sample(self, current_sample):
+        self._current_sample = current_sample
+
+    def append(self, sample):
+        self.position += 1
+
+        # Extend the array if needed
+        if self.position >= self._chain_array_length:
+            self._extend_chain_array()
+
+        # Store the current sample and append to the array
+        self.current_sample = sample
+        self._chain_array[self.position] = sample.list
+
+        # Update the maximum log_likelihood
+        if sample[LOGLKEY] > self.max_log_likelihood:
+            self.max_log_likelihood = sample[LOGLKEY]
+
+    def __getitem__(self, index):
+        if index < 0:
+            index = index + self.position + 1
+
+        if index <= self.position:
+            values = self._chain_array[index]
+            return Sample({k: v for k, v in zip(self.keys, values)})
+        else:
+            raise SamplerError(f"Requested index {index} out of bounds")
+
+    def __setitem__(self, index, sample):
+        if index < 0:
+            index = index + self.position + 1
+
+        self._chain_array[index] = sample.list
+
+    def key_to_idx(self, key):
+        return self.keys.index(key)
+
+    def get_1d_array(self, key):
+        return self._chain_array[: 1 + self.position, self.key_to_idx(key)]
+
+    @property
+    def _random_idx(self):
+        mindex = self._last_minimum_index[1]
+        # Draw from the whole chain if the last ACT is infinite or the chain extends
+        # fewer than 10 ACTs past the minimum index; otherwise draw only from
+        # samples past the minimum index
+        if np.isinf(self.tau_last) or self.position - mindex < 10 * self.tau_last:
+            mindex = 0
+        return np.random.randint(mindex, self.position + 1)
+
+    @property
+    def random_sample(self):
+        return self[self._random_idx]
+
+    @property
+    def fixed_discard(self):
+        return self._fixed_discard
+
+    @fixed_discard.setter
+    def fixed_discard(self, fixed_discard):
+        self._fixed_discard = int(fixed_discard)
+
+    @property
+    def minimum_index(self):
+        """This calculated a minimum index from which to discard samples
+
+        A number of methods are provided for the calculation. A subset are
+        switched off (by `if False` statements) for future development
+
+        """
+        position = self.position
+
+        # Return cached minimum index
+        last_minimum_index = self._last_minimum_index
+        if position == last_minimum_index[0]:
+            return int(last_minimum_index[1])
+
+        # If fixed discard is not yet reached, just return that
+        if position < self.fixed_discard:
+            self.minimum_index_method = "FD"
+            return self.fixed_discard
+
+        # Initialize list of minimum index methods with the fixed discard (FD)
+        minimum_index_list = [self.fixed_discard]
+        minimum_index_method_list = ["FD"]
+
+        # Calculate minimum index from tau
+        if self.tau_last < np.inf:
+            tau = self.tau_last
+        elif len(self.max_tau_dict) == 0:
+            # Bootstrap calculating tau when minimum index has not yet been calculated
+            tau = self._tau_for_full_chain
+        else:
+            tau = np.inf
+
+        if tau < np.inf:
+            minimum_index_list.append(self.burn_in_nact * tau)
+            minimum_index_method_list.append(f"{self.burn_in_nact}tau")
+
+        # Calculate points when log-posterior is within z std of the mean
+        if True:
+            zfactor = 1
+            N = 100
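+            # For an approximately Gaussian posterior the mean log-posterior lies
+            # ~ndim/2 below its maximum; zfactor scales this threshold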
+            delta_lnP = zfactor * self.ndim / 2
+            logl = self.get_1d_array(LOGLKEY)
+            log_prior = self.get_1d_array(LOGPKEY)
+            log_posterior = logl + log_prior
+            max_posterior = np.max(log_posterior)
+
+            ave = pd.Series(log_posterior).rolling(window=N).mean().iloc[N - 1 :]
+            delta = max_posterior - ave
+            passes = ave[delta < delta_lnP]
+            if len(passes) > 0:
+                minimum_index_list.append(passes.index[0] + 1)
+                minimum_index_method_list.append(f"z{zfactor}")
+
+        # Add last minimum_index_method
+        if False:
+            minimum_index_list.append(last_minimum_index[1])
+            minimum_index_method_list.append(last_minimum_index[2])
+
+        # Minimum index set by proposals
+        minimum_index_list.append(self.minimum_index_proposal)
+        minimum_index_method_list.append("PR")
+
+        # Minimum index set by temperature adaptation
+        minimum_index_list.append(self.minimum_index_adapt)
+        minimum_index_method_list.append("AD")
+
+        # Calculate the maximum minimum index and associated method (reporting)
+        minimum_index = int(np.max(minimum_index_list))
+        minimum_index_method = minimum_index_method_list[np.argmax(minimum_index_list)]
+
+        # Cache the method
+        self._last_minimum_index = (position, minimum_index, minimum_index_method)
+        self.minimum_index_method = minimum_index_method
+
+        return minimum_index
+
+    @property
+    def minimum_index_proposal(self):
+        return self._minimum_index_proposal
+
+    @minimum_index_proposal.setter
+    def minimum_index_proposal(self, minimum_index_proposal):
+        if minimum_index_proposal > self._minimum_index_proposal:
+            self._minimum_index_proposal = minimum_index_proposal
+
+    @property
+    def minimum_index_adapt(self):
+        return self._minimum_index_adapt
+
+    @minimum_index_adapt.setter
+    def minimum_index_adapt(self, minimum_index_adapt):
+        if minimum_index_adapt > self._minimum_index_adapt:
+            self._minimum_index_adapt = minimum_index_adapt
+
+    @property
+    def tau(self):
+        """ The maximum ACT over all parameters """
+
+        if self.position in self.max_tau_dict:
+            # If we have the ACT at the current position, return it
+            return self.max_tau_dict[self.position]
+        elif (
+            self.tau_last < np.inf
+            and self.cached_tau_count < 50
+            and self.nsamples_last > 50
+        ):
+            # If we have a recent ACT return it
+            self.cached_tau_count += 1
+            return self.tau_last
+        else:
+            # Calculate the ACT
+            return self.tau_nocache
+
+    @property
+    def tau_nocache(self):
+        """ Calculate tau forcing a recalculation (no cached tau) """
+        tau = max(self.tau_dict.values())
+        self.max_tau_dict[self.position] = tau
+        self.cached_tau_count = 0
+        return tau
+
+    @property
+    def tau_last(self):
+        """ Return the last-calculated tau if it exists, else inf """
+        if len(self.max_tau_dict) > 0:
+            return list(self.max_tau_dict.values())[-1]
+        else:
+            return np.inf
+
+    @property
+    def _tau_for_full_chain(self):
+        """ The maximum ACT over all parameters """
+        return max(self._tau_dict_for_full_chain.values())
+
+    @property
+    def _tau_dict_for_full_chain(self):
+        return self._calculate_tau_dict(minimum_index=0)
+
+    @property
+    def tau_dict(self):
+        """ Calculate a dictionary of tau (ACT) for every parameter """
+        return self._calculate_tau_dict(self.minimum_index)
+
+    def _calculate_tau_dict(self, minimum_index):
+        """ Calculate a dictionary of tau (ACT) for every parameter """
+        logger.debug(f"Calculating tau_dict {self}")
+
+        # If there are too few samples to calculate tau
+        if (self.position - minimum_index) < 2 * self.autocorr_c:
+            return {key: np.inf for key in self.parameter_keys}
+
+        # Choose the minimum index for the ACT calculation
+        last_tau = self.tau_last
+        if self.tau_window is not None and last_tau < np.inf:
+            minimum_index_for_act = max(
+                minimum_index, int(self.position - self.tau_window * last_tau)
+            )
+        else:
+            minimum_index_for_act = minimum_index
+
+        # Calculate a dictionary of tau's for each parameter
+        taus = {}
+        for key in self.parameter_keys:
+            if self.fixed_tau is None:
+                x = self.get_1d_array(key)[minimum_index_for_act:]
+                tau = calculate_tau(x, self.autocorr_c)
+                taux = round(tau, 1)
+            else:
+                taux = self.fixed_tau
+            taus[key] = max(taux, self.min_tau)
+
+        # Cache the last tau dictionary for future use
+        self.last_full_tau_dict = taus
+
+        return taus
+
+    @property
+    def thin(self):
+        if np.isfinite(self.tau):
+            return np.max([1, int(self.thin_by_nact * self.tau)])
+        else:
+            return 1
+
+    @property
+    def nsamples(self):
+        nuseable_steps = self.position - self.minimum_index
+        return int(nuseable_steps / (self.thin_by_nact * self.tau))
+
+    @property
+    def nsamples_last(self):
+        nuseable_steps = self.position - self.minimum_index
+        return int(nuseable_steps / (self.thin_by_nact * self.tau_last))
+
+    @property
+    def samples(self):
+        samples = self._chain_array[self.minimum_index : self.position : self.thin]
+        return pd.DataFrame(samples, columns=self.keys)
+
+    def plot(self, outdir=".", label="label", priors=None, all_samples=None):
+        import matplotlib.pyplot as plt
+
+        fig, axes = plt.subplots(
+            nrows=self.ndim + 3, ncols=2, figsize=(8, 9 + 3 * (self.ndim))
+        )
+        scatter_kwargs = dict(
+            lw=0,
+            marker="o",
+        )
+        K = 1000
+
+        nburn = self.minimum_index
+        plot_setups = zip(
+            [0, nburn, nburn],
+            [nburn, self.position, self.position],
+            [1, 1, self.thin],  # Thin-by factor
+            ["tab:red", "tab:grey", "tab:blue"],  # Color
+            [0.5, 0.05, 0.5],  # Alpha
+            [1, 1, 1],  # Marker size
+        )
+
+        position_indexes = np.arange(self.position + 1)
+
+        # Plot the traceplots
+        for (start, stop, thin, color, alpha, ms) in plot_setups:
+            for ax, key in zip(axes[:, 0], self.keys):
+                xx = position_indexes[start:stop:thin] / K
+                yy = self.get_1d_array(key)[start:stop:thin]
+
+                # Downsample plots to max_pts: avoid memory issues
+                max_pts = 10000
+                while len(xx) > max_pts:
+                    xx = xx[::2]
+                    yy = yy[::2]
+
+                ax.plot(
+                    xx,
+                    yy,
+                    color=color,
+                    alpha=alpha,
+                    ms=ms,
+                    **scatter_kwargs,
+                )
+                ax.set_ylabel(self._get_plot_label_by_key(key, priors))
+                if key not in [LOGLKEY, LOGPKEY]:
+                    msg = r"$\tau=$" + f"{self.last_full_tau_dict[key]:0.1f}"
+                    ax.set_title(msg)
+
+        # Plot the histograms
+        for ax, key in zip(axes[:, 1], self.keys):
+            yy_all = all_samples[key]
+            ax.hist(yy_all, bins=50, alpha=0.6, density=True, color="k")
+
+            yy = self.get_1d_array(key)[nburn : self.position : self.thin]
+            ax.hist(yy, bins=50, alpha=0.8, density=True)
+
+            ax.set_xlabel(self._get_plot_label_by_key(key, priors))
+
+        # Add x-axes labels to the traceplots
+        axes[-1, 0].set_xlabel(r"Iteration $[\times 10^{3}]$")
+
+        # Plot the calculated ACT
+        ax = axes[-1, 0]
+        tausit = np.array(list(self.max_tau_dict.keys()) + [self.position]) / K
+        taus = list(self.max_tau_dict.values()) + [self.tau_last]
+        ax.plot(tausit, taus, color="C3")
+        ax.set(ylabel=r"Maximum $\tau$")
+
+        axes[-1, 1].set_axis_off()
+
+        filename = "{}/{}_checkpoint_trace.png".format(outdir, label)
+        msg = [
+            r"Maximum $\tau$" + f"={self.tau:0.1f} ",
+            r"$n_{\rm samples}=$" + f"{self.nsamples} ",
+        ]
+        if self.thin_by_nact != 1:
+            msg += [
+                r"$n_{\rm samples}^{\rm eff}=$"
+                + f"{int(self.nsamples * self.thin_by_nact)} "
+            ]
+        fig.suptitle(
+            "| ".join(msg),
+            y=1,
+        )
+        fig.tight_layout()
+        fig.savefig(filename, dpi=200)
+        plt.close(fig)
+
+    @staticmethod
+    def _get_plot_label_by_key(key, priors=None):
+        if priors is not None and key in priors:
+            return priors[key].latex_label
+        elif key == LOGLKEY:
+            return LOGLLATEXKEY
+        elif key == LOGPKEY:
+            return LOGPLATEXKEY
+        else:
+            return key
+
+
+class Sample(object):
+    def __init__(self, sample_dict):
+        """A single sample
+
+        Parameters
+        ----------
+        sample_dict: dict
+            A dictionary of the sample
+        """
+
+        self.sample_dict = sample_dict
+        self.keys = list(sample_dict.keys())
+        self.parameter_keys = [k for k in self.keys if k not in [LOGPKEY, LOGLKEY]]
+        self.ndim = len(self.parameter_keys)
+
+    def __getitem__(self, key):
+        return self.sample_dict[key]
+
+    def __setitem__(self, key, value):
+        self.sample_dict[key] = value
+        if key not in self.keys:
+            self.keys = list(self.sample_dict.keys())
+
+    @property
+    def list(self):
+        return list(self.sample_dict.values())
+
+    def __repr__(self):
+        return str(self.sample_dict)
+
+    @property
+    def parameter_only_dict(self):
+        return {key: self.sample_dict[key] for key in self.parameter_keys}
+
+    @property
+    def dict(self):
+        return {key: self.sample_dict[key] for key in self.keys}
+
+    def as_dict(self, keys=None):
+        sdict = self.dict
+        if keys is None:
+            return sdict
+        else:
+            return {key: sdict[key] for key in keys}
+
+    def __eq__(self, other_sample):
+        return self.list == other_sample.list
+
+    def copy(self):
+        return Sample(self.sample_dict.copy())
+
+
+def calculate_tau(x, autocorr_c=5):
+    import emcee
+
+    if LooseVersion(emcee.__version__) < LooseVersion("3"):
+        raise SamplerError("bilby-mcmc requires emcee > 3.0 for autocorr analysis")
+
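+    # A chain with no variation (no accepted moves) has an undefined ACT: return inf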
+    if np.all(np.diff(x) == 0):
+        return np.inf
+    try:
+        # Hard code tol=1: we perform this check internally
+        tau = emcee.autocorr.integrated_time(x, c=autocorr_c, tol=1)[0]
+        if np.isnan(tau):
+            tau = np.inf
+        return tau
+    except emcee.autocorr.AutocorrError:
+        return np.inf
diff --git a/bilby/bilby_mcmc/flows.py b/bilby/bilby_mcmc/flows.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fbaf196b38859660307554e8bc371ce0b03edf8
--- /dev/null
+++ b/bilby/bilby_mcmc/flows.py
@@ -0,0 +1,100 @@
+import torch
+from nflows.distributions.normal import StandardNormal
+from nflows.flows.base import Flow
+from nflows.nn import nets as nets
+from nflows.transforms import (
+    CompositeTransform,
+    MaskedAffineAutoregressiveTransform,
+    RandomPermutation,
+)
+from nflows.transforms.coupling import (
+    AdditiveCouplingTransform,
+    AffineCouplingTransform,
+)
+from nflows.transforms.normalization import BatchNorm
+from torch.nn import functional as F
+
+# Turn off parallelism
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+class NVPFlow(Flow):
+    """A simplified version of Real NVP for 1-dim inputs.
+
+    This implementation uses 1-dim checkerboard masking but doesn't use
+    multi-scaling.
+    Reference:
+    > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
+
+    This class has been modified from the example found at:
+    https://github.com/bayesiains/nflows/blob/master/nflows/flows/realnvp.py
+    """
+
+    def __init__(
+        self,
+        features,
+        hidden_features,
+        num_layers,
+        num_blocks_per_layer,
+        use_volume_preserving=False,
+        activation=F.relu,
+        dropout_probability=0.0,
+        batch_norm_within_layers=False,
+        batch_norm_between_layers=False,
+        random_permutation=True,
+    ):
+
+        if use_volume_preserving:
+            coupling_constructor = AdditiveCouplingTransform
+        else:
+            coupling_constructor = AffineCouplingTransform
+
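+        # 1-dim checkerboard mask: alternate features are transformed and passed
+        # through, flipping between coupling layers (mask *= -1 below)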
+        mask = torch.ones(features)
+        mask[::2] = -1
+
+        def create_resnet(in_features, out_features):
+            return nets.ResidualNet(
+                in_features,
+                out_features,
+                hidden_features=hidden_features,
+                num_blocks=num_blocks_per_layer,
+                activation=activation,
+                dropout_probability=dropout_probability,
+                use_batch_norm=batch_norm_within_layers,
+            )
+
+        layers = []
+        for _ in range(num_layers):
+            transform = coupling_constructor(
+                mask=mask, transform_net_create_fn=create_resnet
+            )
+            layers.append(transform)
+            mask *= -1
+            if batch_norm_between_layers:
+                layers.append(BatchNorm(features=features))
+
+        if random_permutation:
+            layers.append(RandomPermutation(features=features))
+
+        super().__init__(
+            transform=CompositeTransform(layers),
+            distribution=StandardNormal([features]),
+        )
+
+
+class BasicFlow(Flow):
+    def __init__(self, features):
+        transform = CompositeTransform(
+            [
+                MaskedAffineAutoregressiveTransform(
+                    features=features, hidden_features=2 * features
+                ),
+                RandomPermutation(features=features),
+            ]
+        )
+        distribution = StandardNormal(shape=[features])
+        super().__init__(
+            transform=transform,
+            distribution=distribution,
+        )
diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d1d33bdf1b2e6d92802fe34be4c86e2976dd8a5
--- /dev/null
+++ b/bilby/bilby_mcmc/proposals.py
@@ -0,0 +1,1110 @@
+import importlib
+import time
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+from scipy.spatial.distance import jensenshannon
+from scipy.stats import gaussian_kde
+
+from ..core.prior import PriorDict
+from ..core.sampler.base_sampler import SamplerError
+from ..core.utils import logger, reflect
+from ..gw.source import PARAMETER_SETS
+
+
+class ProposalCycle(object):
+    def __init__(self, proposal_list):
+        self.proposal_list = proposal_list
+        self.weights = [prop.weight for prop in self.proposal_list]
+        self.normalized_weights = [w / sum(self.weights) for w in self.weights]
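+        # Pre-draw a long list of proposals according to their weights;
+        # get_proposal then cycles through this list in order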
+        self.weighted_proposal_list = [
+            np.random.choice(self.proposal_list, p=self.normalized_weights)
+            for _ in range(10 * int(1 / min(self.normalized_weights)))
+        ]
+        self.nproposals = len(self.weighted_proposal_list)
+        self._position = 0
+
+    @property
+    def position(self):
+        return self._position
+
+    @position.setter
+    def position(self, position):
+        self._position = np.mod(position, self.nproposals)
+
+    def get_proposal(self):
+        prop = self.weighted_proposal_list[self._position]
+        self.position += 1
+        return prop
+
+    def __str__(self):
+        string = "ProposalCycle:\n"
+        for prop in self.proposal_list:
+            string += f"  {prop}\n"
+        return string
+
+
+class BaseProposal(object):
+    _accepted = 0
+    _rejected = 0
+    __metaclass__ = ABCMeta
+
+    def __init__(self, priors, weight=1, subset=None):
+        self._str_attrs = ["acceptance_ratio", "n"]
+
+        self.parameters = priors.non_fixed_keys
+        self.weight = weight
+        self.subset = subset
+
+        # Restrict to a subset
+        if self.subset is not None:
+            self.parameters = [p for p in self.parameters if p in subset]
+            self._str_attrs.append("parameters")
+
+        self.ndim = len(self.parameters)
+
+        self.prior_boundary_dict = {key: priors[key].boundary for key in priors}
+        self.prior_minimum_dict = {key: np.max(priors[key].minimum) for key in priors}
+        self.prior_maximum_dict = {key: np.min(priors[key].maximum) for key in priors}
+        self.prior_width_dict = {key: np.max(priors[key].width) for key in priors}
+
+    @property
+    def accepted(self):
+        return self._accepted
+
+    @accepted.setter
+    def accepted(self, accepted):
+        self._accepted = accepted
+
+    @property
+    def rejected(self):
+        return self._rejected
+
+    @rejected.setter
+    def rejected(self, rejected):
+        self._rejected = rejected
+
+    @property
+    def acceptance_ratio(self):
+        if self.n == 0:
+            return np.nan
+        else:
+            return self.accepted / self.n
+
+    @property
+    def n(self):
+        return self.accepted + self.rejected
+
+    def __str__(self):
+        msg = [f"{type(self).__name__}("]
+        for attr in self._str_attrs:
+            val = getattr(self, attr, "N/A")
+            if isinstance(val, (float, int)):
+                val = f"{val:1.2g}"
+            msg.append(f"{attr}:{val},")
+        return "".join(msg) + ")"
+
+    def apply_boundaries(self, point):
+        for key in self.parameters:
+            boundary = self.prior_boundary_dict[key]
+            if boundary is None:
+                continue
+            elif boundary == "periodic":
+                point[key] = self.apply_periodic_boundary(key, point[key])
+            elif boundary == "reflective":
+                point[key] = self.apply_reflective_boundary(key, point[key])
+            else:
+                raise SamplerError(f"Boundary {boundary} not implemented")
+        return point
+
+    def apply_periodic_boundary(self, key, val):
+        minimum = self.prior_minimum_dict[key]
+        width = self.prior_width_dict[key]
+        return minimum + np.mod(val - minimum, width)
+
+    def apply_reflective_boundary(self, key, val):
+        minimum = self.prior_minimum_dict[key]
+        width = self.prior_width_dict[key]
+        val_normalised = (val - minimum) / width
+        val_normalised_reflected = reflect(np.array(val_normalised))
+        return minimum + width * val_normalised_reflected
+
+    def __call__(self, chain):
+        sample, log_factor = self.propose(chain)
+        sample = self.apply_boundaries(sample)
+        return sample, log_factor
+
+    @abstractmethod
+    def propose(self, chain):
+        """Propose a new point
+
+        This method must be overridden by implemented proposals. The propose
+        method is called by __call__; boundaries are then applied to the
+        proposed point before it is returned.
+
+        Parameters
+        ----------
+        chain: bilby.bilby_mcmc.chain.Chain
+            The chain to use for the proposal
+
+        Returns
+        -------
+        proposal: bilby.bilby_mcmc.chain.Sample
+            The proposed point
+        log_factor: float
+            The natural-log of the additional factor entering the acceptance
+            probability to ensure detailed balance. For symmetric proposals,
+            a value of 0 should be returned.
+        """
+        pass
+
+    @staticmethod
+    def check_dependencies(warn=True):
+        """Check the dependencies required to use the proposal
+
+        Parameters
+        ----------
+        warn: bool
+            If true, print a warning
+
+        Returns
+        -------
+        check: bool
+            If true, dependencies exist
+        """
+        return True
+
+
+class FixedGaussianProposal(BaseProposal):
+    """A proposal using a fixed non-correlated Gaussian distribution
+
+    Parameters
+    ----------
+    priors: bilby.core.prior.PriorDict
+        The set of priors
+    weight: float
+        Weighting factor
+    subset: list
+        A list of keys to restrict the proposal to (other parameters
+        will be kept fixed)
+    sigma: float
+        The scaling factor for proposals
+    """
+
+    def __init__(self, priors, weight=1, subset=None, sigma=0.01):
+        super(FixedGaussianProposal, self).__init__(priors, weight, subset)
+        self.sigmas = {}
+        for key in self.parameters:
+            if np.isinf(self.prior_width_dict[key]):
+                self.prior_width_dict[key] = 1
+            if isinstance(sigma, float):
+                self.sigmas[key] = sigma
+            elif isinstance(sigma, dict):
+                self.sigmas[key] = sigma[key]
+            else:
+                raise SamplerError("FixedGaussianProposal sigma not understood")
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        for key in self.parameters:
+            sigma = self.prior_width_dict[key] * self.sigmas[key]
+            sample[key] += sigma * np.random.randn()
+        log_factor = 0
+        return sample, log_factor
+
+
+class AdaptiveGaussianProposal(BaseProposal):
+    def __init__(
+        self,
+        priors,
+        weight=1,
+        subset=None,
+        sigma=1,
+        scale_init=1e0,
+        stop=1e5,
+        target_facc=0.234,
+    ):
+        super(AdaptiveGaussianProposal, self).__init__(priors, weight, subset)
+        self.sigmas = {}
+        for key in self.parameters:
+            if np.isinf(self.prior_width_dict[key]):
+                self.prior_width_dict[key] = 1
+            if isinstance(sigma, (float, int)):
+                self.sigmas[key] = sigma
+            elif isinstance(sigma, dict):
+                self.sigmas[key] = sigma[key]
+            else:
+                raise SamplerError("AdaptiveGaussianProposal sigma not understood")
+
+        self.target_facc = target_facc
+        self.scale = scale_init
+        self.stop = stop
+        self._str_attrs.append("scale")
+        self._last_accepted = 0
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        self.update_scale(chain)
+        if np.random.random() < 1e-3:
+            factor = 1e1
+        elif np.random.random() < 1e-4:
+            factor = 1e2
+        else:
+            factor = 1
+        for key in self.parameters:
+            sigma = factor * self.scale * self.prior_width_dict[key] * self.sigmas[key]
+            sample[key] += sigma * np.random.randn()
+        log_factor = 0
+        return sample, log_factor
+
+    def update_scale(self, chain):
+        """
+        The adaptation of the scale follows (35)/(36) of https://arxiv.org/abs/1409.7215
+        """
+        if 0 < self.n < self.stop:
+            s_gamma = (self.stop / self.n) ** 0.2 - 1
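+            # Grow the scale after an acceptance and shrink it after a rejection,
+            # steering the acceptance fraction toward target_facc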
+            if self.accepted > self._last_accepted:
+                self.scale += s_gamma * (1 - self.target_facc) / 100
+            else:
+                self.scale -= s_gamma * self.target_facc / 100
+            self._last_accepted = self.accepted
+            self.scale = max(self.scale, 1 / self.stop)
+
+
+class DifferentialEvolutionProposal(BaseProposal):
+    """A proposal using Differential Evolution
+
+    Parameters
+    ----------
+    priors: bilby.core.prior.PriorDict
+        The set of priors
+    weight: float
+        Weighting factor
+    subset: list
+        A list of keys to restrict the proposal to (other parameters
+        will be kept fixed)
+    mode_hopping_frac: float
+        The fraction of proposals which use 'mode hopping'
+    """
+
+    def __init__(self, priors, weight=1, subset=None, mode_hopping_frac=0.5):
+        super(DifferentialEvolutionProposal, self).__init__(priors, weight, subset)
+        self.mode_hopping_frac = mode_hopping_frac
+
+    def propose(self, chain):
+        theta = chain.current_sample
+        theta1 = chain.random_sample
+        theta2 = chain.random_sample
+        if np.random.rand() > self.mode_hopping_frac:
+            gamma = 1
+        else:
+            # Base jump size
+            gamma = np.random.normal(0, 2.38 / np.sqrt(2 * self.ndim))
+            # Scale uniformly in log between 0.1 and 10 times
+            gamma *= np.exp(np.log(0.1) + np.log(100.0) * np.random.rand())
+
+        for key in self.parameters:
+            theta[key] += gamma * (theta2[key] - theta1[key])
+
+        log_factor = 0
+        return theta, log_factor
+
+
+class UniformProposal(BaseProposal):
+    """A proposal using uniform draws from the prior support
+
+    Parameters
+    ----------
+    priors: bilby.core.prior.PriorDict
+        The set of priors
+    weight: float
+        Weighting factor
+    subset: list
+        A list of keys to restrict the proposal to (other parameters
+        will be kept fixed)
+    """
+
+    def __init__(self, priors, weight=1, subset=None):
+        super(UniformProposal, self).__init__(priors, weight, subset)
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        for key in self.parameters:
+            sample[key] = np.random.uniform(
+                self.prior_minimum_dict[key], self.prior_maximum_dict[key]
+            )
+        log_factor = 0
+        return sample, log_factor
+
+
+class PriorProposal(BaseProposal):
+    """A proposal using draws from the prior distribution
+
+    Note: for priors which use interpolation, this proposal can be problematic
+    as the proposal gets pickled in multiprocessing. Either use serial
+    processing (npool=1) or fall back to a UniformProposal.
+
+    Parameters
+    ----------
+    priors: bilby.core.prior.PriorDict
+        The set of priors
+    weight: float
+        Weighting factor
+    subset: list
+        A list of keys to restrict the proposal to (other parameters
+        will be kept fixed)
+    """
+
+    def __init__(self, priors, weight=1, subset=None):
+        super(PriorProposal, self).__init__(priors, weight, subset)
+        self.priors = PriorDict({key: priors[key] for key in self.parameters})
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        lnp_theta = self.priors.ln_prob(sample.as_dict(self.parameters))
+        prior_sample = self.priors.sample()
+        for key in self.parameters:
+            sample[key] = prior_sample[key]
+        lnp_thetaprime = self.priors.ln_prob(sample.as_dict(self.parameters))
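+        # The proposal density is the prior itself, so the detailed-balance factor
+        # reduces to ln pi(theta) - ln pi(theta')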
+        log_factor = lnp_theta - lnp_thetaprime
+        return sample, log_factor
+
+
+_density_estimate_doc = """ A proposal using draws from a {estimator} fit to the chain
+
+Parameters
+----------
+priors: bilby.core.prior.PriorDict
+    The set of priors
+weight: float
+    Weighting factor
+subset: list
+    A list of keys to restrict the proposal to (other parameters
+    will be kept fixed)
+first_fit: int
+    The number of steps to take before first fitting the density estimate
+fit_multiplier: int
+    The multiplier for the next fit
+nsamples_for_density: int
+    The number of samples to use when fitting the density estimate
+fallback: bilby.bilby_mcmc.proposals.BaseProposal
+    A proposal to use before first training
+scale_fits: int
+    A scaling factor for both the initial and subsequent updates
+"""
+
+
+class DensityEstimateProposal(BaseProposal):
+    def __init__(
+        self,
+        priors,
+        weight=1,
+        subset=None,
+        first_fit=1000,
+        fit_multiplier=10,
+        nsamples_for_density=1000,
+        fallback=AdaptiveGaussianProposal,
+        scale_fits=1,
+    ):
+        super(DensityEstimateProposal, self).__init__(priors, weight, subset)
+        self.nsamples_for_density = nsamples_for_density
+        self.fallback = fallback(priors, weight, subset)
+        self.fit_multiplier = fit_multiplier * scale_fits
+
+        # Counters
+        self.steps_since_refit = 0
+        self.next_refit_time = first_fit * scale_fits
+        self.density = None
+        self.trained = False
+        self._str_attrs.append("trained")
+
+    density_name = None
+    __doc__ = _density_estimate_doc.format(estimator=density_name)
+
+    def _fit(self, dataset):
+        raise NotImplementedError
+
+    def _evaluate(self, point):
+        raise NotImplementedError
+
+    def _sample(self, nsamples=None):
+        raise NotImplementedError
+
+    def refit(self, chain):
+        current_density = self.density
+        start = time.time()
+
+        # Draw two (possibly overlapping) data sets for training and verification
+        dataset = []
+        verification_dataset = []
+        nsamples_for_density = min(chain.position, self.nsamples_for_density)
+        for _ in range(nsamples_for_density):
+            s = chain.random_sample
+            dataset.append([s[key] for key in self.parameters])
+            s = chain.random_sample
+            verification_dataset.append([s[key] for key in self.parameters])
+
+        # Fit the density
+        self.density = self._fit(np.array(dataset).T)
+
+        # Print a log message
+        took = time.time() - start
+        logger.info(
+            f"{self.density_name} construction at {self.steps_since_refit} finished"
+            f" for length {chain.position} chain, took {took:0.2f}s."
+            f" Current accept-ratio={self.acceptance_ratio:0.2f}"
+        )
+
+        # Reset counters for next training
+        self.steps_since_refit = 0
+        self.next_refit_time *= self.fit_multiplier
+
+        # Verify the fitted density hasn't over-constrained (collapsed) relative
+        # to the verification data
+        new_draws = np.atleast_2d(self._sample(1000))
+        verification_dataset = np.array(verification_dataset)
+        fail_parameters = []
+        for ii, key in enumerate(self.parameters):
+            std_draws = np.std(new_draws[:, ii])
+            std_verification = np.std(verification_dataset[:, ii])
+            if std_draws < 0.1 * std_verification:
+                fail_parameters.append(key)
+
+        if len(fail_parameters) > 0:
+            logger.info(
+                f"{self.density_name} construction failed verification and is discarded"
+            )
+            self.density = current_density
+        else:
+            self.trained = True
+
+    def propose(self, chain):
+        self.steps_since_refit += 1
+
+        # Check if we refit
+        testA = self.steps_since_refit >= self.next_refit_time
+        if testA:
+            self.refit(chain)
+
+        # If KDE is yet to be fitted, use the fallback
+        if self.trained is False:
+            return self.fallback.propose(chain)
+
+        # Grab the current sample and its probability under the fitted density
+        theta = chain.current_sample
+        ln_p_theta = self._evaluate(list(theta.as_dict(self.parameters).values()))
+
+        # Sample and update theta
+        new_sample = self._sample(1)
+        for key, val in zip(self.parameters, new_sample):
+            theta[key] = val
+
+        # Calculate the probability of the new sample and the KDE
+        ln_p_thetaprime = self._evaluate(list(theta.as_dict(self.parameters).values()))
+
+        # Calculate Q(theta|theta') / Q(theta'|theta)
+        log_factor = ln_p_theta - ln_p_thetaprime
+
+        return theta, log_factor
+
+
+class KDEProposal(DensityEstimateProposal):
+    density_name = "Gaussian KDE"
+    __doc__ = _density_estimate_doc.format(estimator=density_name)
+
+    def _fit(self, dataset):
+        return gaussian_kde(dataset)
+
+    def _evaluate(self, point):
+        return self.density.logpdf(point)[0]
+
+    def _sample(self, nsamples=None):
+        return np.atleast_1d(np.squeeze(self.density.resample(nsamples)))
+
+
+class GMMProposal(DensityEstimateProposal):
+    density_name = "Gaussian Mixture Model"
+    __doc__ = _density_estimate_doc.format(estimator=density_name)
+
+    def _fit(self, dataset):
+        from sklearn.mixture import GaussianMixture
+
+        density = GaussianMixture(n_components=10)
+        density.fit(dataset.T)
+        return density
+
+    def _evaluate(self, point):
+        return np.squeeze(self.density.score_samples(np.atleast_2d(point)))
+
+    def _sample(self, nsamples=None):
+        return np.squeeze(self.density.sample(n_samples=nsamples)[0])
+
+    @staticmethod
+    def check_dependencies(warn=True):
+        if importlib.util.find_spec("sklearn") is None:
+            if warn:
+                logger.warning(
+                    "Unable to utilise GMMProposal as sklearn is not installed"
+                )
+            return False
+        else:
+            return True
+
+
+class NormalizingFlowProposal(DensityEstimateProposal):
+    density_name = "Normalizing Flow"
+    __doc__ = _density_estimate_doc.format(estimator=density_name) + (
+        """
+        js_factor: float
+            The factor used to set the maximum Jensen-Shannon divergence at
+            which training is terminated (the threshold is
+            js_factor / nsamples_for_density).
+        max_training_epochs: int
+            The maximum number of training steps to take
+        """
+    )
+
+    def __init__(
+        self,
+        priors,
+        weight=1,
+        subset=None,
+        first_fit=1000,
+        fit_multiplier=10,
+        max_training_epochs=1000,
+        scale_fits=1,
+        nsamples_for_density=1000,
+        js_factor=10,
+        fallback=AdaptiveGaussianProposal,
+    ):
+        super(NormalizingFlowProposal, self).__init__(
+            priors=priors,
+            weight=weight,
+            subset=subset,
+            first_fit=first_fit,
+            fit_multiplier=fit_multiplier,
+            nsamples_for_density=nsamples_for_density,
+            fallback=fallback,
+            scale_fits=scale_fits,
+        )
+        self.setup_flow()
+        self.setup_optimizer()
+
+        self.max_training_epochs = max_training_epochs
+        self.js_factor = js_factor
+
+    def setup_flow(self):
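+        # Use a simple masked autoregressive flow for ndim < 3 and a RealNVP-style
+        # coupling flow otherwise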
+        if self.ndim < 3:
+            self.setup_basic_flow()
+        else:
+            self.setup_NVP_flow()
+
+    def setup_NVP_flow(self):
+        from .flows import NVPFlow
+
+        self.flow = NVPFlow(
+            features=self.ndim,
+            hidden_features=self.ndim * 2,
+            num_layers=2,
+            num_blocks_per_layer=2,
+            batch_norm_between_layers=True,
+            batch_norm_within_layers=True,
+        )
+
+    def setup_basic_flow(self):
+        from .flows import BasicFlow
+
+        self.flow = BasicFlow(features=self.ndim)
+
+    def setup_optimizer(self):
+        from torch import optim
+
+        self.optimizer = optim.Adam(self.flow.parameters())
+
+    def get_training_data(self, chain):
+        training_data = []
+        nsamples_for_density = min(chain.position, self.nsamples_for_density)
+        for _ in range(nsamples_for_density):
+            s = chain.random_sample
+            training_data.append([s[key] for key in self.parameters])
+        return training_data
+
+    def _calculate_js(self, validation_samples, training_samples_draw):
+        # Calculate the maximum JS between the validation and draw
+        max_js = 0
+        for i in range(self.ndim):
+            A = validation_samples[:, i]
+            B = training_samples_draw[:, i]
+            xmin = np.min([np.min(A), np.min(B)])
+            xmax = np.min([np.max(A), np.max(B)])
+            xval = np.linspace(xmin, xmax, 100)
+            Apdf = gaussian_kde(A)(xval)
+            Bpdf = gaussian_kde(B)(xval)
+            js = jensenshannon(Apdf, Bpdf)
+            max_js = max(max_js, js)
+        return np.power(max_js, 2)
+
+    def train(self, chain):
+        logger.info("Starting NF training")
+
+        import torch
+
+        start = time.time()
+
+        training_samples = np.array(self.get_training_data(chain))
+        validation_samples = np.array(self.get_training_data(chain))
+
+        training_tensor = torch.tensor(training_samples, dtype=torch.float32)
+
+        max_js_threshold = self.js_factor / self.nsamples_for_density
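+        # Stop training once the maximum JS divergence between validation samples
+        # and flow draws falls below this threshold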
+
+        for epoch in range(1, self.max_training_epochs + 1):
+            self.optimizer.zero_grad()
+            loss = -self.flow.log_prob(inputs=training_tensor).mean()
+            loss.backward()
+            self.optimizer.step()
+
+            # Draw from the current flow
+            self.flow.eval()
+            training_samples_draw = (
+                self.flow.sample(self.nsamples_for_density).detach().numpy()
+            )
+            self.flow.train()
+
+            if np.mod(epoch, 10) == 0:
+                max_js_bits = self._calculate_js(
+                    validation_samples, training_samples_draw
+                )
+                if max_js_bits < max_js_threshold:
+                    logger.info(
+                        f"Training complete after {epoch} steps, "
+                        f"max_js_bits={max_js_bits:0.5f}<{max_js_threshold}"
+                    )
+                    break
+
+        took = time.time() - start
+        logger.info(
+            f"Flow training step ({self.steps_since_refit}) finished"
+            f" for length {chain.position} chain, took {took:0.2f}s."
+            f" Current accept-ratio={self.acceptance_ratio:0.2f}"
+        )
+        self.steps_since_refit = 0
+        self.next_refit_time *= self.fit_multiplier
+        self.trained = True
+
+    def propose(self, chain):
+        import torch
+
+        self.steps_since_refit += 1
+        theta = chain.current_sample
+
+        # Check if we retrain the NF
+        testA = self.steps_since_refit >= self.next_refit_time
+        if testA:
+            self.train(chain)
+
+        if self.trained is False:
+            return self.fallback.propose(chain)
+
+        self.flow.eval()
+        theta_prime_T = self.flow.sample(1)
+
+        logp_theta_prime = self.flow.log_prob(theta_prime_T).detach().numpy()[0]
+        theta_T = torch.tensor(
+            np.atleast_2d([theta[key] for key in self.parameters]), dtype=torch.float32
+        )
+        logp_theta = self.flow.log_prob(theta_T).detach().numpy()[0]
+        log_factor = logp_theta - logp_theta_prime
+
+        flow_sample_values = np.atleast_1d(np.squeeze(theta_prime_T.detach().numpy()))
+        for key, val in zip(self.parameters, flow_sample_values):
+            theta[key] = val
+
+        return theta, float(log_factor)
+
+    @staticmethod
+    def check_dependencies(warn=True):
+        if importlib.util.find_spec("nflows") is None:
+            if warn:
+                logger.warning(
+                    "Unable to utilise NormalizingFlowProposal as nflows is not installed"
+                )
+            return False
+        else:
+            return True
+
+
+class FixedJumpProposal(BaseProposal):
+    def __init__(self, priors, jumps=1, subset=None, weight=1, scale=1e-4):
+        super(FixedJumpProposal, self).__init__(priors, weight, subset)
+        self.scale = scale
+        if isinstance(jumps, (int, float)):
+            self.jumps = {key: jumps for key in self.parameters}
+        elif isinstance(jumps, dict):
+            self.jumps = jumps
+        else:
+            raise SamplerError("jumps not understood")
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        for key, jump in self.jumps.items():
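+            # Jump by +/- the fixed step, plus a small Gaussian fuzz (epsilon) scaled by the prior width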
+            sign = np.random.randint(2) * 2 - 1
+            sample[key] += sign * jump + self.epsilon * self.prior_width_dict[key]
+        log_factor = 0
+        return sample, log_factor
+
+    @property
+    def epsilon(self):
+        return self.scale * np.random.normal()
+
+
+class BaseGravitationalWaveTransientProposal(BaseProposal):
+    def __init__(self, priors, weight=1):
+        super(BaseGravitationalWaveTransientProposal, self).__init__(
+            priors, weight=weight
+        )
+        if "phase" in priors:
+            self.phase_key = "phase"
+        elif "delta_phase" in priors:
+            self.phase_key = "delta_phase"
+        else:
+            self.phase_key = None
+
+    def get_cos_theta_jn(self, sample):
+        if "cos_theta_jn" in sample.parameter_keys:
+            cos_theta_jn = sample["cos_theta_jn"]
+        elif "theta_jn" in sample.parameter_keys:
+            cos_theta_jn = np.cos(sample["theta_jn"])
+        else:
+            raise SamplerError()
+        return cos_theta_jn
+
+    def get_phase(self, sample):
+        if "phase" in sample.parameter_keys:
+            return sample["phase"]
+        elif "delta_phase" in sample.parameter_keys:
+            cos_theta_jn = self.get_cos_theta_jn(sample)
+            delta_phase = sample["delta_phase"]
+            psi = sample["psi"]
+            phase = np.mod(delta_phase - np.sign(cos_theta_jn) * psi, 2 * np.pi)
+        else:
+            raise SamplerError()
+        return phase
+
+    def get_delta_phase(self, phase, sample):
+        cos_theta_jn = self.get_cos_theta_jn(sample)
+        psi = sample["psi"]
+        delta_phase = phase + np.sign(cos_theta_jn) * psi
+        return delta_phase
+
+
+class CorrelatedPolarisationPhaseJump(BaseGravitationalWaveTransientProposal):
+    def __init__(self, priors, weight=1):
+        super(CorrelatedPolarisationPhaseJump, self).__init__(priors, weight=weight)
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        phase = self.get_phase(sample)
+
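+        # Jump in the correlated coordinates alpha = psi + phase and beta = psi - phase,
+        # redrawing one of them uniformly before transforming back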
+        alpha = sample["psi"] + phase
+        beta = sample["psi"] - phase
+
+        draw = np.random.random()
+        if draw < 0.5:
+            alpha = 3.0 * np.pi * np.random.random()
+        else:
+            beta = 3.0 * np.pi * np.random.random() - 2 * np.pi
+
+        # Update
+        sample["psi"] = (alpha + beta) * 0.5
+        phase = (alpha - beta) * 0.5
+
+        if self.phase_key == "delta_phase":
+            sample["delta_phase"] = self.get_delta_phase(phase, sample)
+        else:
+            sample["phase"] = phase
+
+        log_factor = 0
+        return sample, log_factor
+
+
+class PhaseReversalProposal(BaseGravitationalWaveTransientProposal):
+    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1):
+        super(PhaseReversalProposal, self).__init__(priors, weight)
+        self.fuzz = fuzz
+        self.fuzz_sigma = fuzz_sigma
+        if self.phase_key is None:
+            raise SamplerError(
+                f"{type(self).__name__} initialised without a phase prior"
+            )
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        phase = sample[self.phase_key]
+        sample[self.phase_key] = np.mod(phase + np.pi + self.epsilon, 2 * np.pi)
+        log_factor = 0
+        return sample, log_factor
+
+    @property
+    def epsilon(self):
+        if self.fuzz:
+            return np.random.normal(0, self.fuzz_sigma)
+        else:
+            return 0
+
+
+class PolarisationReversalProposal(PhaseReversalProposal):
+    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-3):
+        super(PolarisationReversalProposal, self).__init__(
+            priors, weight, fuzz, fuzz_sigma
+        )
+        self.fuzz = fuzz
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        psi = sample["psi"]
+        sample["psi"] = np.mod(psi + np.pi / 2 + self.epsilon, np.pi)
+        log_factor = 0
+        return sample, log_factor
+
+
+class PhasePolarisationReversalProposal(PhaseReversalProposal):
+    def __init__(self, priors, weight=1, fuzz=True, fuzz_sigma=1e-1):
+        super(PhasePolarisationReversalProposal, self).__init__(
+            priors, weight, fuzz, fuzz_sigma
+        )
+        self.fuzz = fuzz
+
+    def propose(self, chain):
+        sample = chain.current_sample
+        sample[self.phase_key] = np.mod(
+            sample[self.phase_key] + np.pi + self.epsilon, 2 * np.pi
+        )
+        sample["psi"] = np.mod(sample["psi"] + np.pi / 2 + self.epsilon, np.pi)
+        log_factor = 0
+        return sample, log_factor
+
+
+class StretchProposal(BaseProposal):
+    """The Goodman & Weare (2010) Stretch proposal for an MCMC chain
+
+    Implementation of the Stretch proposal using a sample drawn from the chain.
+    We assume the form of g(z) from Equation (9) of [1].
+
+    References
+    ----------
+    [1] Goodman & Weare (2010)
+        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract
+
+    """
+
+    def __init__(self, priors, weight=1, subset=None, scale=2):
+        super(StretchProposal, self).__init__(priors, weight, subset)
+        self.scale = scale
+
+    def propose(self, chain):
+        sample = chain.current_sample
+
+        # Draw a random sample
+        rand = chain.random_sample
+
+        return _stretch_move(sample, rand, self.scale, self.ndim, self.parameters)
+
+
+def _stretch_move(sample, complement, scale, ndim, parameters):
+    # Draw z
+    u = np.random.rand()
+    z = (u * (scale - 1) + 1) ** 2 / scale
+
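+    # The z^(ndim - 1) Jacobian factor is required for detailed balance (Goodman & Weare 2010)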
+    log_factor = (ndim - 1) * np.log(z)
+
+    for key in parameters:
+        sample[key] = complement[key] + (sample[key] - complement[key]) * z
+
+    return sample, log_factor
+
+
+class EnsembleProposal(BaseProposal):
+    """ Base EnsembleProposal class for ensemble-based swap proposals """
+
+    def __init__(self, priors, weight=1):
+        super(EnsembleProposal, self).__init__(priors, weight)
+
+    def __call__(self, chain, chain_complement):
+        sample, log_factor = self.propose(chain, chain_complement)
+        sample = self.apply_boundaries(sample)
+        return sample, log_factor
+
+
+class EnsembleStretch(EnsembleProposal):
+    """The Goodman & Weare (2010) Stretch proposal for an Ensemble
+
+    Implementation of the Stretch proposal using a sample drawn from complement.
+    We assume the form of g(z) from Equation (9) of [1].
+
+    References
+    ----------
+    [1] Goodman & Weare (2010)
+        https://ui.adsabs.harvard.edu/abs/2010CAMCS...5...65G/abstract
+
+    """
+
+    def __init__(self, priors, weight=1, scale=2):
+        super(EnsembleStretch, self).__init__(priors, weight)
+        self.scale = scale
+
+    def propose(self, chain, chain_complement):
+        sample = chain.current_sample
+        complement = chain_complement[
+            np.random.randint(len(chain_complement))
+        ].current_sample
+        return _stretch_move(
+            sample, complement, self.scale, self.ndim, self.parameters
+        )
+
+
+def get_default_ensemble_proposal_cycle(priors):
+    return ProposalCycle([EnsembleStretch(priors)])
+
+
+def get_proposal_cycle(string, priors, L1steps=1, warn=True):
+    big_weight = 10
+    small_weight = 5
+    tiny_weight = 0.1
+
+    if "gwA" in string:
+        # Parameters for learning proposals
+        learning_kwargs = dict(
+            first_fit=1000, nsamples_for_density=10000, fit_multiplier=2
+        )
+
+        plist = [
+            AdaptiveGaussianProposal(priors, weight=small_weight),
+            DifferentialEvolutionProposal(priors, weight=small_weight),
+        ]
+
+        if GMMProposal.check_dependencies(warn=warn) is False:
+            raise SamplerError(
+                "the gwA proposal_cycle required the GMMProposal dependencies"
+            )
+
+        if priors.intrinsic:
+            intrinsic = PARAMETER_SETS["intrinsic"]
+            plist += [
+                AdaptiveGaussianProposal(priors, weight=big_weight, subset=intrinsic),
+                DifferentialEvolutionProposal(
+                    priors, weight=big_weight, subset=intrinsic
+                ),
+                KDEProposal(
+                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                ),
+                GMMProposal(
+                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                ),
+            ]
+
+        if priors.extrinsic:
+            extrinsic = PARAMETER_SETS["extrinsic"]
+            plist += [
+                AdaptiveGaussianProposal(priors, weight=small_weight, subset=extrinsic),
+                DifferentialEvolutionProposal(
+                    priors, weight=big_weight, subset=extrinsic
+                ),
+                KDEProposal(
+                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                ),
+                GMMProposal(
+                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                ),
+            ]
+
+        if priors.mass:
+            mass = PARAMETER_SETS["mass"]
+            plist += [
+                DifferentialEvolutionProposal(priors, weight=small_weight, subset=mass),
+                GMMProposal(
+                    priors, weight=small_weight, subset=mass, **learning_kwargs
+                ),
+            ]
+
+        if priors.spin:
+            spin = PARAMETER_SETS["spin"]
+            plist += [
+                DifferentialEvolutionProposal(priors, weight=small_weight, subset=spin),
+                GMMProposal(
+                    priors, weight=small_weight, subset=spin, **learning_kwargs
+                ),
+            ]
+        if priors.precession:
+            measured_spin = ["chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane"]
+            plist += [
+                AdaptiveGaussianProposal(
+                    priors, weight=small_weight, subset=measured_spin
+                ),
+            ]
+
+        if priors.mass and priors.spin:
+            primary_spin_and_q = PARAMETER_SETS["primary_spin_and_q"]
+            plist += [
+                DifferentialEvolutionProposal(
+                    priors, weight=small_weight, subset=primary_spin_and_q
+                ),
+            ]
+
+        if getattr(priors, "tidal", False):
+            tidal = PARAMETER_SETS["tidal"]
+            plist += [
+                DifferentialEvolutionProposal(
+                    priors, weight=small_weight, subset=tidal
+                ),
+                PriorProposal(priors, weight=small_weight, subset=tidal),
+            ]
+        if priors.phase:
+            plist += [
+                PhaseReversalProposal(priors, weight=tiny_weight),
+            ]
+        if priors.phase and "psi" in priors.non_fixed_keys:
+            plist += [
+                CorrelatedPolarisationPhaseJump(priors, weight=tiny_weight),
+                PhasePolarisationReversalProposal(priors, weight=tiny_weight),
+            ]
+        for key in ["time_jitter", "psi", "phi_12", "tilt_2", "lambda_1", "lambda_2"]:
+            if key in priors.non_fixed_keys:
+                plist.append(PriorProposal(priors, subset=[key], weight=tiny_weight))
+        if "chi_1_in_plane" in priors and "chi_2_in_plane" in priors:
+            in_plane = ["chi_1_in_plane", "chi_2_in_plane", "phi_12"]
+            plist.append(UniformProposal(priors, subset=in_plane, weight=tiny_weight))
+        if any("recalib_" in key for key in priors):
+            calibration = [key for key in priors if "recalib_" in key]
+            plist.append(PriorProposal(priors, subset=calibration, weight=small_weight))
+    else:
+        plist = [
+            AdaptiveGaussianProposal(priors, weight=big_weight),
+            DifferentialEvolutionProposal(priors, weight=big_weight),
+            UniformProposal(priors, weight=tiny_weight),
+            KDEProposal(priors, weight=big_weight, scale_fits=L1steps),
+        ]
+        if GMMProposal.check_dependencies(warn=warn):
+            plist.append(GMMProposal(priors, weight=big_weight, scale_fits=L1steps))
+        if NormalizingFlowProposal.check_dependencies(warn=warn):
+            plist.append(
+                NormalizingFlowProposal(priors, weight=big_weight, scale_fits=L1steps)
+            )
+
+    plist = remove_proposals_using_string(plist, string)
+    return ProposalCycle(plist)
+
+
+def remove_proposals_using_string(plist, string):
+    mapping = dict(
+        DE=DifferentialEvolutionProposal,
+        AG=AdaptiveGaussianProposal,
+        ST=StretchProposal,
+        FG=FixedGaussianProposal,
+        NF=NormalizingFlowProposal,
+        KD=KDEProposal,
+        GM=GMMProposal,
+        PR=PriorProposal,
+        UN=UniformProposal,
+    )
+
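+    # Proposals are removed by appending "no" + key to the cycle string, e.g.
+    # "defaultnoNFnoGM" drops the NormalizingFlow and GMM proposals.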
+    for element in string.split("no")[1:]:
+        if element in mapping:
+            plist = [p for p in plist if isinstance(p, mapping[element]) is False]
+    return plist
diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cad693ae9c1879388141d99d49fc282a3e0487a
--- /dev/null
+++ b/bilby/bilby_mcmc/sampler.py
@@ -0,0 +1,1348 @@
+import datetime
+import os
+import signal
+import time
+from collections import Counter
+
+import numpy as np
+import pandas as pd
+
+from ..core.result import rejection_sample
+from ..core.sampler.base_sampler import MCMCSampler, ResumeError, SamplerError
+from ..core.utils import check_directory_exists_and_if_not_mkdir, logger, safe_file_dump
+from . import proposals
+from .chain import Chain, Sample
+from .utils import LOGLKEY, LOGPKEY, ConvergenceInputs, ParallelTemperingInputs
+
+
+class Bilby_MCMC(MCMCSampler):
+    """The built-in Bilby MCMC sampler
+
+    Parameters
+    ----------
+    likelihood: likelihood.Likelihood
+        An object with a log_likelihood method
+    priors: bilby.core.prior.PriorDict, dict
+        Priors to be used in the search.
+        This has attributes for each parameter to be sampled.
+    outdir: str, optional
+        Name of the output directory
+    label: str, optional
+        Naming scheme of the output files
+    use_ratio: bool, optional
+        Switch to set whether or not you want to use the log-likelihood ratio
+        or just the log-likelihood
+    skip_import_verification: bool
+        If true, skip the check that the sampler is installed. This is
+        only advisable for testing environments
+    check_point_plot: bool
+        If true, create plots at the check point
+    check_point_delta_t: float
+        The time in seconds after which to checkpoint (defaults to 30 minutes)
+    diagnostic: bool
+        If true, create deep-diagnostic plots used for checking convergence
+        problems.
+    resume: bool
+        If true, resume from any existing check point files
+    exit_code: int
+        The code on which to raise if exiting
+
+    Sampling Parameters
+    -------------------
+    nsamples: int (1000)
+        The number of samples to draw
+    nensemble: int (1)
+        The number of ensemble-chains to run (with periodic communication)
+    pt_ensemble: bool (False)
+        If true, run a parallel-tempered set of chains for each
+        ensemble-chain (in which case the total number of chains is
+        nensemble * ntemps). Else, only the zero-ensemble chain is run with
+        parallel-tempering (in which case the total number of chains is
+        nensemble + ntemps - 1).
+    ntemps: int (1)
+        The number of parallel-tempered chains to run
+    Tmax: float (None)
+        If given, the maximum temperature used to set the initial temperature
+        ladder
+    Tmax_from_SNR: float (20)
+        (Alternative to Tmax): The SNR to estimate an appropriate Tmax from.
+    initial_betas: list (None)
+        (Alternative to Tmax and Tmax_from_SNR): If given, an initial choice of
+        the inverse temperature ladder.
+    pt_rejection_sample: bool (False)
+        If true, use rejection sampling to draw samples from the pt-chains.
+    adapt, adapt_t0, adapt_nu: bool, float, float (True, 100, 10)
+        Whether to use adaptation and the adaptation parameters.
+        See arXiv:1501.05823 for a description of adapt_t0 and adapt_nu.
+    burn_in_nact, thin_by_nact, fixed_discard: float, float, float (10, 1, 0)
+        The number of auto-correlation times to discard for burn-in and to
+        thin by. The fixed_discard is the number of steps discarded before
+        automatic autocorrelation time analysis begins.
+    autocorr_c: float (5)
+        The step-size for the window search. See emcee.autocorr.integrated_time
+        for additional details.
+    L1steps: int
+        The number of internal steps to take. Improves the scaling performance
+        of multiprocessing. Note, all ACTs are calculated based on the saved
+        steps. So, the total ACT (or number of steps) is L1steps * tau
+        (or L1steps * position).
+    L2steps: int
+        The number of steps to take before swapping between parallel-tempered
+        and ensemble chains.
+    npool: int
+        The number of multiprocessing cores to use. For efficiency, the total
+        number of chains should be an integer multiple of npool.
+    printdt: float
+        Print an update on the progress every printdt seconds. Note, each print
+        requires an evaluation of the ACT, so short print intervals are unwise.
+    min_tau: int (1)
+        The minimum allowed ACT. Can be used to force a larger ACT.
+    proposal_cycle: str, bilby.core.sampler.bilby_mcmc.proposals.ProposalCycle
+        Either a string pointing to one of the built-in proposal cycles or,
+        a proposal cycle.
+    stop_after_convergence: bool (False)
+        If running with parallel-tempered chains, stop updating the chains once
+        they have converged. After this time, random samples will be drawn at
+        swap time.
+    fixed_tau: int
+        A fixed value for the ACT: used for testing purposes.
+    tau_window: int, None
+        Using tau', a previous estimate of tau, calculate the new tau using
+        the last tau_window * tau' steps. If None, the entire chain is used.
+    evidence_method: str, [stepping_stone, thermodynamic]
+        The evidence calculation method to use. Defaults to stepping_stone, but
+        the results of all available methods are stored in the ln_z_dict.
+
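+    Examples
+    --------
+    A minimal sketch of running this sampler through the standard bilby
+    interface; the likelihood, priors, and keyword values here are
+    placeholders rather than recommendations::
+
+        import bilby
+
+        result = bilby.run_sampler(
+            likelihood=likelihood,
+            priors=priors,
+            sampler="bilby_mcmc",
+            nsamples=1000,
+            ntemps=8,
+            npool=4,
+        )
+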
+    """
+
+    default_kwargs = dict(
+        nsamples=1000,
+        nensemble=1,
+        pt_ensemble=False,
+        ntemps=1,
+        Tmax=None,
+        Tmax_from_SNR=20,
+        initial_betas=None,
+        adapt=True,
+        adapt_t0=100,
+        adapt_nu=10,
+        pt_rejection_sample=False,
+        burn_in_nact=10,
+        thin_by_nact=1,
+        fixed_discard=0,
+        autocorr_c=5,
+        L1steps=100,
+        L2steps=3,
+        npool=1,
+        printdt=60,
+        min_tau=1,
+        proposal_cycle="default",
+        stop_after_convergence=False,
+        fixed_tau=None,
+        tau_window=None,
+        evidence_method="stepping_stone",
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        skip_import_verification=True,
+        check_point_plot=True,
+        check_point_delta_t=1800,
+        diagnostic=False,
+        resume=True,
+        exit_code=130,
+        **kwargs,
+    ):
+
+        super(Bilby_MCMC, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            skip_import_verification=skip_import_verification,
+            exit_code=exit_code,
+            **kwargs,
+        )
+
+        self.check_point_plot = check_point_plot
+        self.diagnostic = diagnostic
+        self.kwargs["target_nsamples"] = self.kwargs["nsamples"]
+        self.npool = self.kwargs["npool"]
+        self.L1steps = self.kwargs["L1steps"]
+        self.L2steps = self.kwargs["L2steps"]
+        self.pt_inputs = ParallelTemperingInputs(
+            **{key: self.kwargs[key] for key in ParallelTemperingInputs._fields}
+        )
+        self.convergence_inputs = ConvergenceInputs(
+            **{key: self.kwargs[key] for key in ConvergenceInputs._fields}
+        )
+        self.proposal_cycle = self.kwargs["proposal_cycle"]
+        self.pt_rejection_sample = self.kwargs["pt_rejection_sample"]
+        self.evidence_method = self.kwargs["evidence_method"]
+
+        self.printdt = self.kwargs["printdt"]
+        check_directory_exists_and_if_not_mkdir(self.outdir)
+        self.resume = resume
+        self.check_point_delta_t = check_point_delta_t
+        self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)
+
+        self.verify_configuration()
+
+        try:
+            signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
+            signal.signal(signal.SIGINT, self.write_current_state_and_exit)
+            signal.signal(signal.SIGALRM, self.write_current_state_and_exit)
+        except AttributeError:
+            logger.debug(
+                "Setting signal attributes unavailable on this system. "
+                "This is likely the case if you are running on a Windows machine"
+                " and is no further concern."
+            )
+
+    def verify_configuration(self):
+        if self.convergence_inputs.burn_in_nact / self.kwargs["target_nsamples"] > 0.1:
+            logger.warning("Burn-in inefficiency fraction greater than 10%")
+
+    def _translate_kwargs(self, kwargs):
+        if "printdt" not in kwargs:
+            for equiv in ["print_dt", "print_update"]:
+                if equiv in kwargs:
+                    kwargs["printdt"] = kwargs.pop(equiv)
+        if "npool" not in kwargs:
+            for equiv in self.npool_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs["npool"] = kwargs.pop(equiv)
+
+    @property
+    def target_nsamples(self):
+        return self.kwargs["target_nsamples"]
+
+    def run_sampler(self):
+        self._setup_pool()
+        self.setup_chain_set()
+        self.start_time = datetime.datetime.now()
+        self.draw()
+        self._close_pool()
+        self.check_point(ignore_time=True)
+
+        self.result = self.add_data_to_result(
+            result=self.result,
+            ptsampler=self.ptsampler,
+            outdir=self.outdir,
+            label=self.label,
+            make_plots=self.check_point_plot,
+        )
+
+        return self.result
+
+    @staticmethod
+    def add_data_to_result(result, ptsampler, outdir, label, make_plots):
+        result.samples = ptsampler.samples
+        result.log_likelihood_evaluations = result.samples[LOGLKEY]
+        result.log_prior_evaluations = result.samples[LOGPKEY]
+        ptsampler.compute_evidence(
+            outdir=outdir,
+            label=label,
+            make_plots=make_plots,
+        )
+        result.log_evidence = ptsampler.ln_z
+        result.log_evidence_err = ptsampler.ln_z_err
+        result.sampling_time = datetime.timedelta(seconds=ptsampler.sampling_time)
+        result.meta_data["bilby_mcmc"] = dict(
+            tau=ptsampler.tau,
+            convergence_inputs=ptsampler.convergence_inputs._asdict(),
+            pt_inputs=ptsampler.pt_inputs._asdict(),
+            total_steps=ptsampler.position,
+            nsamples=ptsampler.nsamples,
+        )
+        if ptsampler.pool is not None:
+            npool = ptsampler.pool._processes
+        else:
+            npool = 1
+        result.meta_data["run_statistics"] = dict(
+            nlikelihood=ptsampler.position * ptsampler.L1steps * ptsampler._nsamplers,
+            neffsamples=ptsampler.nsamples * ptsampler.convergence_inputs.thin_by_nact,
+            sampling_time_s=result.sampling_time.seconds,
+            ncores=npool,
+        )
+
+        return result
+
+    def setup_chain_set(self):
+        if os.path.isfile(self.resume_file) and self.resume is True:
+            self.read_current_state()
+            self.ptsampler.pool = self.pool
+        else:
+            self.init_ptsampler()
+
+    def init_ptsampler(self):
+
+        logger.info(f"Initializing BilbyPTMCMCSampler with:\n{self.get_setup_string()}")
+        self.ptsampler = BilbyPTMCMCSampler(
+            convergence_inputs=self.convergence_inputs,
+            pt_inputs=self.pt_inputs,
+            proposal_cycle=self.proposal_cycle,
+            pt_rejection_sample=self.pt_rejection_sample,
+            pool=self.pool,
+            use_ratio=self.use_ratio,
+            evidence_method=self.evidence_method,
+        )
+
+    def get_setup_string(self):
+        string = (
+            f"  Convergence settings: {self.convergence_inputs}\n"
+            f"  Parallel-tempering settings: {self.pt_inputs}\n"
+            f"  proposal_cycle: {self.proposal_cycle}\n"
+            f"  pt_rejection_sample: {self.pt_rejection_sample}"
+        )
+        return string
+
+    def draw(self):
+        self._steps_since_last_print = 0
+        self._time_since_last_print = 0
+        logger.info(f"Drawing {self.target_nsamples} samples")
+        logger.info(f"Checkpoint every {self.check_point_delta_t}s")
+        logger.info(f"Print update every {self.printdt}s")
+
+        while True:
+            t0 = datetime.datetime.now()
+            self.ptsampler.step_all_chains()
+            dt = (datetime.datetime.now() - t0).total_seconds()
+            self.ptsampler.sampling_time += dt
+            self._time_since_last_print += dt
+            self._steps_since_last_print += self.ptsampler.L1steps
+
+            if self._time_since_last_print > self.printdt:
+                tp0 = datetime.datetime.now()
+                self.print_progress()
+                tp = datetime.datetime.now()
+                ppt_frac = (tp - tp0).total_seconds() / self._time_since_last_print
+                if ppt_frac > 0.01:
+                    logger.warning(
+                        f"Non-negligible print progress time (ppt_frac={ppt_frac:0.2f})"
+                    )
+                self._steps_since_last_print = 0
+                self._time_since_last_print = 0
+
+            self.check_point()
+
+            if self.ptsampler.nsamples_last >= self.target_nsamples:
+                # Perform a second check without cached values
+                if self.ptsampler.nsamples_nocache >= self.target_nsamples:
+                    logger.info("Reached convergence: exiting sampling")
+                    break
+
+    def check_point(self, ignore_time=False):
+        tS = (datetime.datetime.now() - self.start_time).total_seconds()
+        if os.path.isfile(self.resume_file):
+            tR = time.time() - os.path.getmtime(self.resume_file)
+        else:
+            tR = np.inf
+
+        if ignore_time or np.min([tS, tR]) > self.check_point_delta_t:
+            logger.info("Checkpoint start")
+            self.write_current_state()
+            self.print_long_progress()
+            logger.info("Checkpoint finished")
+
+    def _remove_checkpoint(self):
+        """Remove checkpointed state"""
+        if os.path.isfile(self.resume_file):
+            os.remove(self.resume_file)
+
+    def read_current_state(self):
+        import dill
+
+        with open(self.resume_file, "rb") as file:
+            self.ptsampler = dill.load(file)
+            if self.ptsampler.pt_inputs != self.pt_inputs:
+                msg = (
+                    f"pt_inputs has changed: {self.ptsampler.pt_inputs} "
+                    f"-> {self.pt_inputs}"
+                )
+                raise ResumeError(msg)
+            self.ptsampler.set_convergence_inputs(self.convergence_inputs)
+            self.ptsampler.proposal_cycle = self.proposal_cycle
+            self.ptsampler.pt_rejection_sample = self.pt_rejection_sample
+
+        logger.info(
+            f"Loaded resume file {self.resume_file} "
+            f"with {self.ptsampler.position} steps "
+            f"setup:\n{self.get_setup_string()}"
+        )
+
+    def write_current_state_and_exit(self, signum=None, frame=None):
+        """
+        Make sure that if a pool of jobs is running only the parent tries to
+        checkpoint and exit. Only the parent has a 'pool' attribute.
+        """
+        if self.npool == 1 or getattr(self, "pool", None) is not None:
+            if signum == 14:
+                logger.info(
+                    "Run interrupted by alarm signal {}: checkpoint and exit on {}".format(
+                        signum, self.exit_code
+                    )
+                )
+            else:
+                logger.info(
+                    "Run interrupted by signal {}: checkpoint and exit on {}".format(
+                        signum, self.exit_code
+                    )
+                )
+            self.write_current_state()
+            self._close_pool()
+            os._exit(self.exit_code)
+
+    def write_current_state(self):
+        import dill
+
+        logger.debug("Check point")
+        check_directory_exists_and_if_not_mkdir(self.outdir)
+
+        _pool = self.ptsampler.pool
+        self.ptsampler.pool = None
+        if dill.pickles(self.ptsampler):
+            safe_file_dump(self.ptsampler, self.resume_file, dill)
+            logger.info("Written checkpoint file {}".format(self.resume_file))
+        else:
+            logger.warning(
+                "Cannot write pickle resume file! "
+                "Job will not resume if interrupted."
+            )
+        self.ptsampler.pool = _pool
+
+    def print_long_progress(self):
+        self.print_per_proposal()
+        self.print_tau_dict()
+        if self.ptsampler.ntemps > 1:
+            self.print_pt_acceptance()
+        if self.ptsampler.nensemble > 1:
+            self.print_ensemble_acceptance()
+        if self.check_point_plot:
+            self.plot_progress(
+                self.ptsampler, self.label, self.outdir, self.priors, self.diagnostic
+            )
+            self.ptsampler.compute_evidence(
+                outdir=self.outdir, label=self.label, make_plots=True
+            )
+
+    def print_ensemble_acceptance(self):
+        logger.info(f"Ensemble swaps = {self.ptsampler.swap_counter['ensemble']}")
+        logger.info(self.ptsampler.ensemble_proposal_cycle)
+
+    def print_progress(self):
+        position = self.ptsampler.position
+
+        # Total sampling time
+        sampling_time = datetime.timedelta(seconds=self.ptsampler.sampling_time)
+        time = str(sampling_time).split(".")[0]
+
+        # Time for last evaluation set
+        time_per_eval_ms = (
+            1000 * self._time_since_last_print / self._steps_since_last_print
+        )
+
+        # Pull out progress summary
+        tau = self.ptsampler.tau
+        nsamples = self.ptsampler.nsamples
+        minimum_index = self.ptsampler.primary_sampler.chain.minimum_index
+        method = self.ptsampler.primary_sampler.chain.minimum_index_method
+        mindex_str = f"{minimum_index:0.2e}({method})"
+        alpha = self.ptsampler.primary_sampler.acceptance_ratio
+        maxl = self.ptsampler.primary_sampler.chain.max_log_likelihood
+
+        nlikelihood = position * self.L1steps * self.ptsampler._nsamplers
+        eff = 100 * nsamples / nlikelihood
+
+        # Estimated time until finish (ETF)
+        if tau < np.inf:
+            remaining_samples = self.target_nsamples - nsamples
+            remaining_evals = (
+                remaining_samples
+                * self.convergence_inputs.thin_by_nact
+                * tau
+                * self.L1steps
+            )
+            remaining_time_s = time_per_eval_ms * 1e-3 * remaining_evals
+            remaining_time_dt = datetime.timedelta(seconds=remaining_time_s)
+            if remaining_samples > 0:
+                remaining_time = str(remaining_time_dt).split(".")[0]
+            else:
+                remaining_time = "0"
+        else:
+            remaining_time = "-"
+
+        msg = (
+            f"{position:0.2e}|{time}|{mindex_str}|t={tau:0.0f}|"
+            f"n={nsamples:0.0f}|a={alpha:0.2f}|e={eff:0.1e}%|"
+            f"{time_per_eval_ms:0.2f}ms/ev|maxl={maxl:0.2f}|"
+            f"ETF={remaining_time}"
+        )
+
+        if self.pt_rejection_sample:
+            count = self.ptsampler.rejection_sampling_count
+            rse = 100 * count / nsamples
+            msg += f"|rse={rse:0.2f}%"
+
+        print(msg, flush=True)
+
+    def print_per_proposal(self):
+        logger.info("Zero-temperature proposals:")
+        for prop in self.ptsampler[0].proposal_cycle.proposal_list:
+            logger.info(prop)
+
+    def print_pt_acceptance(self):
+        logger.info(f"Temperature swaps = {self.ptsampler.swap_counter['temperature']}")
+        for column in self.ptsampler.sampler_list_of_tempered_lists:
+            for ii, sampler in enumerate(column):
+                total = sampler.pt_accepted + sampler.pt_rejected
+                beta = sampler.beta
+                if total > 0:
+                    ratio = f"{sampler.pt_accepted / total:0.2f}"
+                else:
+                    ratio = "-"
+                logger.info(
+                    f"Temp:{ii}<->{ii+1}|"
+                    f"beta={beta:0.4g}|"
+                    f"hot-samp={sampler.nsamples}|"
+                    f"swap={ratio}|"
+                    f"conv={sampler.chain.converged}|"
+                )
+
+    def print_tau_dict(self):
+        msg = f"Current taus={self.ptsampler.primary_sampler.chain.tau_dict}"
+        logger.info(msg)
+
+    @staticmethod
+    def plot_progress(ptsampler, label, outdir, priors, diagnostic=False):
+        logger.info("Creating diagnostic plots")
+        for ii, row in ptsampler.sampler_dictionary.items():
+            for jj, sampler in enumerate(row):
+                plot_label = f"{label}_E{sampler.Eindex}_T{sampler.Tindex}"
+                if diagnostic is True or sampler.beta == 1:
+                    sampler.chain.plot(
+                        outdir=outdir,
+                        label=plot_label,
+                        priors=priors,
+                        all_samples=ptsampler.samples,
+                    )
+
+    def _setup_pool(self):
+        if self.npool > 1:
+            logger.info(f"Setting up multiproccesing pool with {self.npool} processes")
+            import multiprocessing
+
+            self.pool = multiprocessing.Pool(
+                processes=self.npool,
+                initializer=_initialize_global_variables,
+                initargs=(
+                    self.likelihood,
+                    self.priors,
+                    self._search_parameter_keys,
+                    self.use_ratio,
+                ),
+            )
+        else:
+            self.pool = None
+
+        _initialize_global_variables(
+            likelihood=self.likelihood,
+            priors=self.priors,
+            search_parameter_keys=self._search_parameter_keys,
+            use_ratio=self.use_ratio,
+        )
+
+    def _close_pool(self):
+        if getattr(self, "pool", None) is not None:
+            logger.info("Starting to close worker pool.")
+            self.pool.close()
+            self.pool.join()
+            self.pool = None
+            logger.info("Finished closing worker pool.")
+
+
+class BilbyPTMCMCSampler(object):
+    def __init__(
+        self,
+        convergence_inputs,
+        pt_inputs,
+        proposal_cycle,
+        pt_rejection_sample,
+        pool,
+        use_ratio,
+        evidence_method,
+    ):
+
+        self.set_pt_inputs(pt_inputs)
+        self.use_ratio = use_ratio
+        self.setup_sampler_dictionary(convergence_inputs, proposal_cycle)
+        self.set_convergence_inputs(convergence_inputs)
+        self.pt_rejection_sample = pt_rejection_sample
+        self.pool = pool
+        self.evidence_method = evidence_method
+
+        # Initialize counters
+        self.swap_counter = Counter()
+        self.swap_counter["temperature"] = 0
+        self.swap_counter["L2-temperature"] = 0
+        self.swap_counter["ensemble"] = 0
+        self.swap_counter["L2-ensemble"] = int(self.L2steps / 2) + 1
+
+        self._nsamples_dict = {}
+        self.ensemble_proposal_cycle = proposals.get_default_ensemble_proposal_cycle(
+            _priors
+        )
+        self.sampling_time = 0
+        self.ln_z_dict = dict()
+        self.ln_z_err_dict = dict()
+
+    def get_initial_betas(self):
+        pt_inputs = self.pt_inputs
+        if self.ntemps == 1:
+            betas = np.array([1])
+        elif pt_inputs.initial_betas is not None:
+            betas = np.array(pt_inputs.initial_betas)
+        elif pt_inputs.Tmax is not None:
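+            # Geometrically spaced inverse temperatures from 1 down to 1 / Tmax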
+            betas = np.logspace(0, -np.log10(pt_inputs.Tmax), pt_inputs.ntemps)
+        elif pt_inputs.Tmax_from_SNR is not None:
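+            # Choose Tmax so the hottest chain targets an expected
+            # log-likelihood of ndim / 2: for a signal with matched-filter SNR
+            # rho, the maximum log-likelihood is ~ rho^2 / 2, giving
+            # Tmax ~ rho^2 / ndim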
+            ndim = len(_priors.non_fixed_keys)
+            target_hot_likelihood = ndim / 2
+            Tmax = pt_inputs.Tmax_from_SNR ** 2 / (2 * target_hot_likelihood)
+            betas = np.logspace(0, -np.log10(Tmax), pt_inputs.ntemps)
+        else:
+            raise SamplerError("Unable to set temperature ladder from inputs")
+
+        if len(betas) != self.ntemps:
+            raise SamplerError("Temperatures do not match ntemps")
+
+        return betas
+
+    def setup_sampler_dictionary(self, convergence_inputs, proposal_cycle):
+
+        betas = self.get_initial_betas()
+        logger.info(
+            f"Initializing BilbyPTMCMCSampler with:"
+            f"ntemps={self.ntemps},"
+            f"nensemble={self.nensemble},"
+            f"pt_ensemble={self.pt_ensemble},"
+            f"initial_betas={betas}\n"
+        )
+        self.sampler_dictionary = dict()
+        for Tindex, beta in enumerate(betas):
+            if beta == 1 or self.pt_ensemble:
+                n = self.nensemble
+            else:
+                n = 1
+            temp_sampler_list = [
+                BilbyMCMCSampler(
+                    beta=beta,
+                    Tindex=Tindex,
+                    Eindex=Eindex,
+                    convergence_inputs=convergence_inputs,
+                    proposal_cycle=proposal_cycle,
+                    use_ratio=self.use_ratio,
+                )
+                for Eindex in range(n)
+            ]
+            self.sampler_dictionary[Tindex] = temp_sampler_list
+
+        # Store data
+        self._nsamplers = len(self.sampler_list)
+
+    @property
+    def sampler_list(self):
+        """ A list of all individual samplers """
+        return [s for item in self.sampler_dictionary.values() for s in item]
+
+    @sampler_list.setter
+    def sampler_list(self, sampler_list):
+        for sampler in sampler_list:
+            self.sampler_dictionary[sampler.Tindex][sampler.Eindex] = sampler
+
+    def sampler_list_by_column(self, column):
+        return [row[column] for row in self.sampler_dictionary.values()]
+
+    @property
+    def sampler_list_of_tempered_lists(self):
+        if self.pt_ensemble:
+            return [self.sampler_list_by_column(ii) for ii in range(self.nensemble)]
+        else:
+            return [self.sampler_list_by_column(0)]
+
+    @property
+    def tempered_sampler_list(self):
+        return [s for s in self.sampler_list if s.beta < 1]
+
+    @property
+    def zerotemp_sampler_list(self):
+        return [s for s in self.sampler_list if s.beta == 1]
+
+    @property
+    def primary_sampler(self):
+        return self.sampler_dictionary[0][0]
+
+    def set_pt_inputs(self, pt_inputs):
+        logger.info(f"Setting parallel tempering inputs={pt_inputs}")
+        self.pt_inputs = pt_inputs
+
+        # Pull out only what is needed
+        self.ntemps = pt_inputs.ntemps
+        self.nensemble = pt_inputs.nensemble
+        self.pt_ensemble = pt_inputs.pt_ensemble
+        self.adapt = pt_inputs.adapt
+        self.adapt_t0 = pt_inputs.adapt_t0
+        self.adapt_nu = pt_inputs.adapt_nu
+
+    def set_convergence_inputs(self, convergence_inputs):
+        logger.info(f"Setting convergence_inputs={convergence_inputs}")
+        self.convergence_inputs = convergence_inputs
+        self.L1steps = convergence_inputs.L1steps
+        self.L2steps = convergence_inputs.L2steps
+        for sampler in self.sampler_list:
+            sampler.set_convergence_inputs(convergence_inputs)
+
+    @property
+    def tau(self):
+        return self.primary_sampler.chain.tau
+
+    @property
+    def minimum_index(self):
+        return self.primary_sampler.chain.minimum_index
+
+    @property
+    def nsamples(self):
+        pos = self.primary_sampler.chain.position
+        if hasattr(self, "_nsamples_dict") is False:
+            self._nsamples_dict = {}
+        if pos in self._nsamples_dict:
+            return self._nsamples_dict[pos]
+        logger.debug(f"Calculating nsamples at {pos}")
+        self._nsamples_dict[pos] = self._calculate_nsamples()
+        return self._nsamples_dict[pos]
+
+    @property
+    def nsamples_last(self):
+        if len(self._nsamples_dict) > 0:
+            return list(self._nsamples_dict.values())[-1]
+        else:
+            return 0
+
+    @property
+    def nsamples_nocache(self):
+        for sampler in self.sampler_list:
+            sampler.chain.tau_nocache
+        pos = self.primary_sampler.chain.position
+        self._nsamples_dict[pos] = self._calculate_nsamples()
+        return self._nsamples_dict[pos]
+
+    def _calculate_nsamples(self):
+        nsamples_list = []
+        for sampler in self.zerotemp_sampler_list:
+            nsamples_list.append(sampler.nsamples)
+        if self.pt_rejection_sample:
+            for samp in self.sampler_list[1:]:
+                nsamples_list.append(
+                    len(samp.rejection_sample_zero_temperature_samples())
+                )
+        return sum(nsamples_list)
+
+    @property
+    def samples(self):
+        sample_list = []
+        for sampler in self.zerotemp_sampler_list:
+            sample_list.append(sampler.samples)
+        if self.pt_rejection_sample:
+            for sampler in self.tempered_sampler_list:
+                sample_list.append(sampler.samples)
+        return pd.concat(sample_list)
+
+    @property
+    def position(self):
+        return self.primary_sampler.chain.position
+
+    @property
+    def evaluations(self):
+        return int(self.position * len(self.sampler_list))
+
+    def __getitem__(self, index):
+        return self.sampler_list[index]
+
+    def step_all_chains(self):
+        if self.pool:
+            self.sampler_list = self.pool.map(call_step, self.sampler_list)
+        else:
+            for ii, sampler in enumerate(self.sampler_list):
+                self.sampler_list[ii] = sampler.step()
+
+        if self.nensemble > 1 and self.swap_counter["L2-ensemble"] >= self.L2steps:
+            self.swap_counter["ensemble"] += 1
+            self.swap_counter["L2-ensemble"] = 0
+            self.ensemble_step()
+
+        if self.ntemps > 1 and self.swap_counter["L2-temperature"] >= self.L2steps:
+            self.swap_counter["temperature"] += 1
+            self.swap_counter["L2-temperature"] = 0
+            self.swap_tempered_chains()
+            if self.position < self.adapt_t0 * 10:
+                if self.adapt:
+                    self.adapt_temperatures()
+            elif self.adapt:
+                logger.info(
+                    f"Adaptation of temperature chains finished at step {self.position}"
+                )
+                self.adapt = False
+
+        self.swap_counter["L2-ensemble"] += 1
+        self.swap_counter["L2-temperature"] += 1
+
+    @staticmethod
+    def _get_sample_to_swap(sampler):
+        if sampler.chain.converged is False:
+            v = sampler.chain[-1]
+        else:
+            v = sampler.chain.random_sample
+        logl = v[LOGLKEY]
+        return v, logl
+
+    def swap_tempered_chains(self):
+        if self.pt_ensemble:
+            Eindexs = range(self.nensemble)
+        else:
+            Eindexs = [0]
+        for Eindex in Eindexs:
+            for Tindex in range(self.ntemps - 1):
+                sampleri = self.sampler_dictionary[Tindex][Eindex]
+                vi, logli = self._get_sample_to_swap(sampleri)
+                betai = sampleri.beta
+
+                samplerj = self.sampler_dictionary[Tindex + 1][Eindex]
+                vj, loglj = self._get_sample_to_swap(samplerj)
+                betaj = samplerj.beta
+
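+                # Parallel-tempering swap acceptance:
+                # alpha = exp[(beta_j - beta_i) * (logl_i - logl_j)]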
+                dbeta = betaj - betai
+                with np.errstate(over="ignore"):
+                    alpha_swap = np.exp(dbeta * (logli - loglj))
+
+                if np.random.uniform(0, 1) <= alpha_swap:
+                    sampleri.chain[-1] = vj
+                    samplerj.chain[-1] = vi
+                    self.sampler_dictionary[Tindex][Eindex] = sampleri
+                    self.sampler_dictionary[Tindex + 1][Eindex] = samplerj
+                    sampleri.pt_accepted += 1
+                else:
+                    sampleri.pt_rejected += 1
+
+    def ensemble_step(self):
+        for Tindex, sampler_list in self.sampler_dictionary.items():
+            if len(sampler_list) > 1:
+                for Eindex, sampler in enumerate(sampler_list):
+                    curr = sampler.chain.current_sample
+                    proposal = self.ensemble_proposal_cycle.get_proposal()
+                    complement = [s.chain for s in sampler_list if s != sampler]
+                    prop, log_factor = proposal(sampler.chain, complement)
+                    logp = sampler.log_prior(prop)
+
+                    if logp == -np.inf:
+                        sampler.reject_proposal(curr, proposal)
+                        self.sampler_dictionary[Tindex][Eindex] = sampler
+                        continue
+
+                    prop[LOGPKEY] = logp
+                    prop[LOGLKEY] = sampler.log_likelihood(prop)
+                    alpha = np.exp(
+                        log_factor
+                        + sampler.beta * prop[LOGLKEY]
+                        + prop[LOGPKEY]
+                        - sampler.beta * curr[LOGLKEY]
+                        - curr[LOGPKEY]
+                    )
+
+                    if np.random.uniform(0, 1) <= alpha:
+                        sampler.accept_proposal(prop, proposal)
+                    else:
+                        sampler.reject_proposal(curr, proposal)
+                    self.sampler_dictionary[Tindex][Eindex] = sampler
+
+    def adapt_temperatures(self):
+        """Adapt the temperature of the chains
+
+        Using the dynamic temperature selection described in arXiv:1501.05823,
+        adapt the chains to target a constant swap ratio. This method is based
+        on github.com/willvousden/ptemcee/tree/master/ptemcee
+        """
+
+        self.primary_sampler.chain.minimum_index_adapt = self.position
+        tt = self.swap_counter["temperature"]
+        for sampler_list in self.sampler_list_of_tempered_lists:
+            betas = np.array([s.beta for s in sampler_list])
+            ratios = np.array([s.acceptance_ratio for s in sampler_list[:-1]])
+
+            # Modulate temperature adjustments with a hyperbolic decay.
+            decay = self.adapt_t0 / (tt + self.adapt_t0)
+            kappa = decay / self.adapt_nu
+
+            # Construct temperature adjustments.
+            dSs = kappa * (ratios[:-1] - ratios[1:])
+
+            # Compute new ladder (hottest and coldest chains don't move).
+            deltaTs = np.diff(1 / betas[:-1])
+            deltaTs *= np.exp(dSs)
+            betas[1:-1] = 1 / (np.cumsum(deltaTs) + 1 / betas[0])
+            for sampler, beta in zip(sampler_list, betas):
+                sampler.beta = beta
+
+    @property
+    def ln_z(self):
+        return self.ln_z_dict.get(self.evidence_method, np.nan)
+
+    @property
+    def ln_z_err(self):
+        return self.ln_z_err_dict.get(self.evidence_method, np.nan)
+
+    def compute_evidence(self, outdir, label, make_plots=True):
+        if self.ntemps == 1:
+            return
+        kwargs = dict(outdir=outdir, label=label, make_plots=make_plots)
+        methods = dict(
+            thermodynamic=self.thermodynamic_integration_evidence,
+            stepping_stone=self.stepping_stone_evidence,
+        )
+        for key, method in methods.items():
+            ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs)
+            self.ln_z_dict[key] = ln_z
+            self.ln_z_err_dict[key] = ln_z_err
+            logger.info(
+                f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method"
+            )
+
+    def compute_evidence_per_ensemble(self, method, kwargs):
+        from scipy.special import logsumexp
+
+        if self.ntemps == 1:
+            return np.nan, np.nan
+
+        lnZ_list = []
+        lnZerr_list = []
+        for index, ptchain in enumerate(self.sampler_list_of_tempered_lists):
+            lnZ, lnZerr = method(ptchain, **kwargs)
+            lnZ_list.append(lnZ)
+            lnZerr_list.append(lnZerr)
+
+        N = len(lnZ_list)
+
+        # Average lnZ
+        lnZ = logsumexp(lnZ_list, b=1.0 / N)
+
+        # Propagate uncertainty in combined evidence
+        lnZerr = 0.5 * logsumexp(2 * np.array(lnZerr_list), b=1.0 / N)
+
+        return lnZ, lnZerr
+
+    def thermodynamic_integration_evidence(
+        self, ptchain, outdir, label, make_plots=True
+    ):
+        """Computes the evidence using thermodynamic integration
+
+        We compute the evidence without the burn-in samples and without thinning
+        """
+        from scipy.stats import sem
+
+        betas = []
+        mean_lnlikes = []
+        sem_lnlikes = []
+        for sampler in ptchain:
+            lnlikes = sampler.chain.get_1d_array(LOGLKEY)
+            mindex = sampler.chain.minimum_index
+            lnlikes = lnlikes[mindex:]
+            mean_lnlikes.append(np.mean(lnlikes))
+            sem_lnlikes.append(sem(lnlikes))
+            betas.append(sampler.beta)
+
+        # Convert to array and re-order
+        betas = np.array(betas)[::-1]
+        mean_lnlikes = np.array(mean_lnlikes)[::-1]
+        sem_lnlikes = np.array(sem_lnlikes)[::-1]
+
+        lnZ, lnZerr = self._compute_evidence_from_mean_lnlikes(betas, mean_lnlikes)
+
+        if make_plots:
+            plot_label = f"{label}_E{ptchain[0].Eindex}"
+            self._create_lnZ_plots(
+                betas=betas,
+                mean_lnlikes=mean_lnlikes,
+                outdir=outdir,
+                label=plot_label,
+                sem_lnlikes=sem_lnlikes,
+            )
+
+        return lnZ, lnZerr
+
+    def stepping_stone_evidence(self, ptchain, outdir, label, make_plots=True):
+        """
+        Compute the evidence using the stepping stone approximation.
+
+        See https://arxiv.org/abs/1810.04488 and
+        https://pubmed.ncbi.nlm.nih.gov/21187451/ for details.
+
+        The uncertainty is estimated with a block-bootstrap over the per-step
+        ratios, following Maturana-Russel et al. (2019).
+
+        Returns
+        -------
+        ln_z: float
+            Estimate of the natural log evidence
+        ln_z_err: float
+            Estimate of the uncertainty in the evidence
+        """
+        # Order in increasing beta
+        ptchain.reverse()
+
+        # Get maximum usable set of samples across the ptchain
+        min_index = max([samp.chain.minimum_index for samp in ptchain])
+        max_index = min([len(samp.chain.get_1d_array(LOGLKEY)) for samp in ptchain])
+        tau = self.tau
+
+        if max_index - min_index <= 1 or np.isinf(tau):
+            return np.nan, np.nan
+
+        # Read in log likelihoods
+        ln_likes = np.array(
+            [samp.chain.get_1d_array(LOGLKEY)[min_index:max_index] for samp in ptchain]
+        )[:-1].T
+
+        # Thin to only independent samples
+        ln_likes = ln_likes[:: int(self.tau), :]
+        steps = ln_likes.shape[0]
+
+        # Calculate delta betas
+        betas = np.array([samp.beta for samp in ptchain])
+
+        ln_z, ln_ratio = self._calculate_stepping_stone(betas, ln_likes)
+
+        # Implementation of the bootstrap method described in Maturana-Russel
+        # et al. (2019) to estimate the evidence uncertainty.
+        ll = 50  # Block length
+        repeats = 100  # Repeats
+        ln_z_realisations = []
+        try:
+            for _ in range(repeats):
+                idxs = [np.random.randint(i, i + ll) for i in range(steps - ll)]
+                ln_z_realisations.append(
+                    self._calculate_stepping_stone(betas, ln_likes[idxs, :])[0]
+                )
+            ln_z_err = np.std(ln_z_realisations)
+        except ValueError:
+            logger.info("Failed to estimate stepping stone uncertainty")
+            ln_z_err = np.nan
+
+        if make_plots:
+            plot_label = f"{label}_E{ptchain[0].Eindex}"
+            self._create_stepping_stone_plot(
+                means=ln_ratio,
+                outdir=outdir,
+                label=plot_label,
+            )
+
+        return ln_z, ln_z_err
+
+    @staticmethod
+    def _calculate_stepping_stone(betas, ln_likes):
+        from scipy.special import logsumexp
+
+        n_samples = ln_likes.shape[0]
+        d_betas = betas[1:] - betas[:-1]
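+        # ln r_k = log[(1/n) * sum_i exp(d_beta_k * ln L_ik)]; ln Z is the sum
+        # of the ln r_k over the temperature steps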
+        ln_ratio = logsumexp(d_betas * ln_likes, axis=0) - np.log(n_samples)
+        return sum(ln_ratio), ln_ratio
+
+    @staticmethod
+    def _compute_evidence_from_mean_lnlikes(betas, mean_lnlikes):
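+        # Thermodynamic integration: ln Z = int_0^1 <ln L>_beta d beta via the
+        # trapezoidal rule; the error estimate compares against the same
+        # integral evaluated on every other beta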
+        lnZ = np.trapz(mean_lnlikes, betas)
+        z2 = np.trapz(mean_lnlikes[::-1][::2][::-1], betas[::-1][::2][::-1])
+        lnZerr = np.abs(lnZ - z2)
+        return lnZ, lnZerr
+
+    def _create_lnZ_plots(self, betas, mean_lnlikes, outdir, label, sem_lnlikes=None):
+        import matplotlib.pyplot as plt
+
+        logger.debug("Creating thermodynamic evidence diagnostic plot")
+
+        fig, ax1 = plt.subplots()
+        if betas[-1] == 0:
+            x, y = betas[:-1], mean_lnlikes[:-1]
+        else:
+            x, y = betas, mean_lnlikes
+        if sem_lnlikes is not None:
+            ax1.errorbar(x, y, sem_lnlikes, fmt="-")
+        else:
+            ax1.plot(x, y, "-o")
+        ax1.set_xscale("log")
+        ax1.set_xlabel(r"$\beta$")
+        ax1.set_ylabel(r"$\langle \log(\mathcal{L}) \rangle$")
+
+        plt.tight_layout()
+        fig.savefig("{}/{}_beta_lnl.png".format(outdir, label))
+        plt.close()
+
+    def _create_stepping_stone_plot(self, means, outdir, label):
+        import matplotlib.pyplot as plt
+
+        logger.debug("Creating stepping stone evidence diagnostic plot")
+
+        n_steps = len(means)
+
+        fig, axes = plt.subplots(nrows=2, figsize=(8, 10))
+
+        ax = axes[0]
+        ax.plot(np.arange(1, n_steps + 1), means)
+        ax.set_xlabel("$k$")
+        ax.set_ylabel("$r_{k}$")
+
+        ax = axes[1]
+        ax.plot(np.arange(1, n_steps + 1), np.cumsum(means))
+        ax.set_xlabel("$k$")
+        ax.set_ylabel("Cumulative $\\ln Z$")
+
+        plt.tight_layout()
+        fig.savefig("{}/{}_stepping_stone.png".format(outdir, label))
+        plt.close()
+
+    @property
+    def rejection_sampling_count(self):
+        if self.pt_rejection_sample:
+            counts = 0
+            for column in self.sampler_list_of_tempered_lists:
+                for sampler in column:
+                    counts += sampler.rejection_sampling_count
+            return counts
+        else:
+            return None
+
+
+class BilbyMCMCSampler(object):
+    def __init__(
+        self,
+        convergence_inputs,
+        proposal_cycle=None,
+        beta=1,
+        Tindex=0,
+        Eindex=0,
+        use_ratio=False,
+    ):
+        self.beta = beta
+        self.Tindex = Tindex
+        self.Eindex = Eindex
+        self.use_ratio = use_ratio
+
+        self.parameters = _priors.non_fixed_keys
+        self.ndim = len(self.parameters)
+
+        full_sample_dict = _priors.sample()
+        initial_sample = {
+            k: v for k, v in full_sample_dict.items() if k in _priors.non_fixed_keys
+        }
+        initial_sample = Sample(initial_sample)
+        initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
+        initial_sample[LOGPKEY] = self.log_prior(initial_sample)
+
+        self.chain = Chain(initial_sample=initial_sample)
+        self.set_convergence_inputs(convergence_inputs)
+
+        self.accepted = 0
+        self.rejected = 0
+        self.pt_accepted = 0
+        self.pt_rejected = 0
+        self.rejection_sampling_count = 0
+
+        if isinstance(proposal_cycle, str):
+            # Only print warnings for the primary sampler
+            if Tindex == 0 and Eindex == 0:
+                warn = True
+            else:
+                warn = False
+
+            self.proposal_cycle = proposals.get_proposal_cycle(
+                proposal_cycle, _priors, L1steps=self.chain.L1steps, warn=warn
+            )
+        elif isinstance(proposal_cycle, proposals.ProposalCycle):
+            self.proposal_cycle = proposal_cycle
+        else:
+            raise SamplerError("Proposal cycle not understood")
+
+        if self.Tindex == 0 and self.Eindex == 0:
+            logger.info(f"Using {self.proposal_cycle}")
+
+    def set_convergence_inputs(self, convergence_inputs):
+        for key, val in convergence_inputs._asdict().items():
+            setattr(self.chain, key, val)
+        self.target_nsamples = convergence_inputs.target_nsamples
+        self.stop_after_convergence = convergence_inputs.stop_after_convergence
+
+    def log_likelihood(self, sample):
+        _likelihood.parameters.update(sample.sample_dict)
+
+        if self.use_ratio:
+            logl = _likelihood.log_likelihood_ratio()
+        else:
+            logl = _likelihood.log_likelihood()
+
+        return logl
+
+    def log_prior(self, sample):
+        return _priors.ln_prob(sample.parameter_only_dict)
+
+    def accept_proposal(self, prop, proposal):
+        self.chain.append(prop)
+        self.accepted += 1
+        proposal.accepted += 1
+
+    def reject_proposal(self, curr, proposal):
+        self.chain.append(curr)
+        self.rejected += 1
+        proposal.rejected += 1
+
+    def step(self):
+        if self.stop_after_convergence and self.chain.converged:
+            return self
+
+        internal_steps = 0
+        internal_accepted = 0
+        internal_rejected = 0
+        curr = self.chain.current_sample.copy()
+        while internal_steps < self.chain.L1steps:
+            internal_steps += 1
+            proposal = self.proposal_cycle.get_proposal()
+            prop, log_factor = proposal(self.chain)
+            logp = self.log_prior(prop)
+
+            if np.isinf(logp) or np.isnan(logp):
+                internal_rejected += 1
+                proposal.rejected += 1
+                continue
+
+            prop[LOGPKEY] = logp
+            prop[LOGLKEY] = self.log_likelihood(prop)
+
+            if np.isinf(prop[LOGLKEY]) or np.isnan(prop[LOGLKEY]):
+                internal_rejected += 1
+                proposal.rejected += 1
+                continue
+
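+            # Tempered Metropolis-Hastings ratio: the likelihood enters with
+            # power beta, the prior at full strength, plus any asymmetric
+            # proposal (Hastings) factor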
+            with np.errstate(over="ignore"):
+                alpha = np.exp(
+                    log_factor
+                    + self.beta * prop[LOGLKEY]
+                    + prop[LOGPKEY]
+                    - self.beta * curr[LOGLKEY]
+                    - curr[LOGPKEY]
+                )
+
+            if np.random.uniform(0, 1) <= alpha:
+                internal_accepted += 1
+                proposal.accepted += 1
+                curr = prop
+                self.chain.current_sample = curr
+            else:
+                internal_rejected += 1
+                proposal.rejected += 1
+
+        self.chain.append(curr)
+        self.rejected += internal_rejected
+        self.accepted += internal_accepted
+        return self
+
+    @property
+    def nsamples(self):
+        nsamples = self.chain.nsamples
+        if nsamples > self.target_nsamples and self.chain.converged is False:
+            logger.debug(f"Temperature {self.Tindex} chain reached convergence")
+            self.chain.converged = True
+        return nsamples
+
+    @property
+    def acceptance_ratio(self):
+        return self.accepted / (self.accepted + self.rejected)
+
+    @property
+    def samples(self):
+        if self.beta == 1:
+            return self.chain.samples
+        else:
+            return self.rejection_sample_zero_temperature_samples(print_message=True)
+
+    def rejection_sample_zero_temperature_samples(self, print_message=False):
+        beta = self.beta
+        chain = self.chain
+        hot_samples = pd.DataFrame(
+            chain._chain_array[chain.minimum_index : chain.position], columns=chain.keys
+        )
+        if len(hot_samples) == 0:
+            logger.debug(
+                f"Rejection sampling for Temp {self.Tindex} failed: "
+                "no usable hot samples"
+            )
+            return hot_samples
+
+        # Pull out log likelihood
+        zerotemp_logl = hot_samples[LOGLKEY]
+
+        # Revert to true likelihood if needed
+        if _use_ratio:
+            zerotemp_logl += _likelihood.noise_log_likelihood()
+
+        # Calculate normalised weights
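+        # (importance weight of a hot sample w.r.t. the beta=1 target:
+        # w_i proportional to exp[(1 - beta) * logl_i])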
+        log_weights = (1 - beta) * zerotemp_logl
+        max_weight = np.max(log_weights)
+        unnormalised_weights = np.exp(log_weights - max_weight)
+        weights = unnormalised_weights / np.sum(unnormalised_weights)
+
+        # Rejection sample
+        samples = rejection_sample(hot_samples, weights)
+
+        # Logging
+        self.rejection_sampling_count = len(samples)
+
+        if print_message:
+            logger.info(
+                f"Rejection sampling Temp {self.Tindex}, beta={beta:0.2f} "
+                f"yielded {len(samples)} samples"
+            )
+        return samples
+
+
+# Methods used to aid parallelisation:
+
+
+def call_step(sampler):
+    sampler = sampler.step()
+    return sampler
+
+
+_likelihood = None
+_priors = None
+_search_parameter_keys = None
+_use_ratio = False
+
+
+def _initialize_global_variables(
+    likelihood,
+    priors,
+    search_parameter_keys,
+    use_ratio,
+):
+    """
+    Store a global copy of the likelihood, priors, and search keys for
+    multiprocessing.
+    """
+    global _likelihood
+    global _priors
+    global _search_parameter_keys
+    global _use_ratio
+    _likelihood = likelihood
+    _priors = priors
+    _search_parameter_keys = search_parameter_keys
+    _use_ratio = use_ratio
diff --git a/bilby/bilby_mcmc/utils.py b/bilby/bilby_mcmc/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..25c8ba2327d7bacc9927c6796aeeb48b09647155
--- /dev/null
+++ b/bilby/bilby_mcmc/utils.py
@@ -0,0 +1,38 @@
+from collections import namedtuple
+
+LOGLKEY = "logl"
+LOGLLATEXKEY = r"$\log\mathcal{L}$"
+LOGPKEY = "logp"
+LOGPLATEXKEY = r"$\log\pi$"
+
+ConvergenceInputs = namedtuple(
+    "ConvergenceInputs",
+    [
+        "autocorr_c",
+        "burn_in_nact",
+        "thin_by_nact",
+        "fixed_discard",
+        "target_nsamples",
+        "stop_after_convergence",
+        "L1steps",
+        "L2steps",
+        "min_tau",
+        "fixed_tau",
+        "tau_window",
+    ],
+)
+
+ParallelTemperingInputs = namedtuple(
+    "ParallelTemperingInputs",
+    [
+        "ntemps",
+        "nensemble",
+        "Tmax",
+        "Tmax_from_SNR",
+        "initial_betas",
+        "adapt",
+        "adapt_t0",
+        "adapt_nu",
+        "pt_ensemble",
+    ],
+)
diff --git a/bilby/core/result.py b/bilby/core/result.py
index 00b66869e14586eb330da5d7a4fd6447f9e941bc..0c234df022ac513507f3a4547c07a29a93a55bbf 100644
--- a/bilby/core/result.py
+++ b/bilby/core/result.py
@@ -1776,7 +1776,7 @@ class ResultList(list):
         # check which kind of sampler was used: MCMC or Nested Sampling
         if result._nested_samples is not None:
             posteriors, result = self._combine_nested_sampled_runs(result)
-        elif result.sampler in ["bilbymcmc"]:
+        elif result.sampler in ["bilby_mcmc", "bilbymcmc"]:
             posteriors, result = self._combine_mcmc_sampled_runs(result)
         else:
             posteriors = [res.posterior for res in self]
diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index 93202387f35767e4e58dad4298ca309c4b441506..e586f138510cf5c8637b3faa27b4d3753a26b131 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -22,10 +22,11 @@ from .pymultinest import Pymultinest
 from .ultranest import Ultranest
 from .fake_sampler import FakeSampler
 from .dnest4 import DNest4
+from bilby.bilby_mcmc import Bilby_MCMC
 from . import proposal
 
 IMPLEMENTED_SAMPLERS = {
-    'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
+    'bilby_mcmc': Bilby_MCMC, 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
     'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai,
     'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler,
     'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord,
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 548e1a9b35c6b1e5a5e86264ad422d00844d28cf..5e25a7e54247fa0dcc7f5e9841e308bbfd071173 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -668,6 +668,10 @@ class SamplerError(Error):
     """ Base class for Error related to samplers in this module """
 
 
+class ResumeError(Error):
+    """ Class for errors arising from resuming runs """
+
+
 class SamplerNotInstalledError(SamplerError):
     """ Base class for Error raised by not installed samplers """
 
diff --git a/bilby/core/utils/io.py b/bilby/core/utils/io.py
index 080e591f5c181aa7460356ce9ed34f093fa74521..49f77b7739baf8bf3490185bad3bdf67a30b699c 100644
--- a/bilby/core/utils/io.py
+++ b/bilby/core/utils/io.py
@@ -29,6 +29,7 @@ class BilbyJsonEncoder(json.JSONEncoder):
     def default(self, obj):
         from ..prior import MultivariateGaussianDist, Prior, PriorDict
         from ...gw.prior import HealPixMapPriorDist
+        from ...bilby_mcmc.proposals import ProposalCycle
         if isinstance(obj, np.integer):
             return int(obj)
         if isinstance(obj, np.floating):
@@ -39,6 +40,8 @@ class BilbyJsonEncoder(json.JSONEncoder):
             return {'__prior__': True, '__module__': obj.__module__,
                     '__name__': obj.__class__.__name__,
                     'kwargs': dict(obj.get_instantiation_dict())}
+        if isinstance(obj, ProposalCycle):
+            return str(obj)
         try:
             from astropy import cosmology as cosmo, units
             if isinstance(obj, cosmo.FLRW):
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 291cf412da04f2626de4d24dc4a6216c89438106..23356d48970b911d0d1d85923010e7124a005dba 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -957,7 +957,8 @@ class CalibrationPriorDict(PriorDict):
 
     @staticmethod
     def from_envelope_file(envelope_file, minimum_frequency,
-                           maximum_frequency, n_nodes, label):
+                           maximum_frequency, n_nodes, label,
+                           boundary="reflective"):
         """
         Load in the calibration envelope.
 
@@ -980,6 +981,8 @@ class CalibrationPriorDict(PriorDict):
             Number of nodes for the spline.
         label: str
             Label for the names of the parameters, e.g., `recalib_H1_`
+        boundary: None, 'reflective', 'periodic'
+            The type of prior boundary to assign (default: 'reflective')
 
         Returns
         =======
@@ -1013,14 +1016,14 @@ class CalibrationPriorDict(PriorDict):
             prior[name] = Gaussian(mu=amplitude_mean_nodes[ii],
                                    sigma=amplitude_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_phase_{}".format(label, ii)
             latex_label = "$\\phi^{}_{}$".format(label, ii)
             prior[name] = Gaussian(mu=phase_mean_nodes[ii],
                                    sigma=phase_sigma_nodes[ii],
                                    name=name, latex_label=latex_label,
-                                   boundary='reflective')
+                                   boundary=boundary)
         for ii in range(n_nodes):
             name = "recalib_{}_frequency_{}".format(label, ii)
             latex_label = "$f^{}_{}$".format(label, ii)
diff --git a/optional_requirements.txt b/optional_requirements.txt
index aad48e7d087f2717a6f42483fc95805d3e763fd3..86acd396e5bba1759be86b07ade32ff1b3dfad60 100644
--- a/optional_requirements.txt
+++ b/optional_requirements.txt
@@ -5,3 +5,5 @@ theano
 plotly
 tables
 pyfftw
+scikit-learn
+nflows
diff --git a/setup.py b/setup.py
index 575b97ce124f0ba2e93f4fbd29268013588a333f..49747e8bf0a8e619840aff5f55e0a623d1de102e 100644
--- a/setup.py
+++ b/setup.py
@@ -85,7 +85,8 @@ setup(name='bilby',
       version=VERSION,
       packages=['bilby', 'bilby.core', 'bilby.core.prior', 'bilby.core.sampler',
                 'bilby.core.utils', 'bilby.gw', 'bilby.gw.detector',
-                'bilby.gw.sampler', 'bilby.hyper', 'bilby.gw.eos', 'cli_bilby'],
+                'bilby.gw.sampler', 'bilby.hyper', 'bilby.gw.eos', 'bilby.bilby_mcmc',
+                'cli_bilby'],
       package_dir={'bilby': 'bilby', 'cli_bilby': 'cli_bilby'},
       package_data={'bilby.gw': ['prior_files/*'],
                     'bilby.gw.detector': ['noise_curves/*.txt', 'detectors/*'],
diff --git a/test/bilby_mcmc/chain.py b/test/bilby_mcmc/chain.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8017dc814369fedeea6f4cc4f3121971fb7b69a
--- /dev/null
+++ b/test/bilby_mcmc/chain.py
@@ -0,0 +1,227 @@
+import os
+import shutil
+import unittest
+
+import bilby
+from bilby.bilby_mcmc.chain import Chain, Sample, calculate_tau
+from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
+from bilby.core.sampler.base_sampler import SamplerError
+import numpy as np
+import pandas as pd
+
+
+class TestChain(unittest.TestCase):
+    def setUp(self):
+        self.initial_sample = self.create_random_sample()
+        self.outdir = "chain_test"
+        if os.path.isdir(self.outdir) is False:
+            os.mkdir(self.outdir)
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
+
+    def create_random_sample(self):
+        return Sample({
+            "a": np.random.normal(0, 1),
+            "b": np.random.normal(0, 1),
+            LOGLKEY: np.random.normal(0, 1),
+            LOGPKEY: -1
+        })
+
+    def create_chain(self, n=1000):
+        chain = Chain(initial_sample=self.initial_sample)
+        for i in range(n):
+            chain.append(self.create_random_sample())
+        return chain
+
+    def test_initialize(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        self.assertEqual(chain.position, 0)
+
+    def test_append(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        chain.append(self.create_random_sample())
+        self.assertEqual(chain.position, 1)
+        self.assertEqual(len(chain.get_1d_array('a')), 2)
+
+    def test_append_within_init_space(self):
+        chain = Chain(initial_sample=self.initial_sample)
+        N = chain.block_length - 1
+        for i in range(N):
+            chain.append(self.create_random_sample())
+
+        self.assertEqual(chain.position, N)
+
+        # N samples + 1 initial position
+        self.assertEqual(len(chain.get_1d_array('a')), N + 1)
+
+    def test_append_with_extending(self):
+        block_length = 100
+        chain = Chain(initial_sample=self.initial_sample, block_length=block_length)
+
+        # Check the array is the block length
+        self.assertEqual(len(chain._chain_array), block_length)
+        for i in range(3 * block_length):
+            chain.append(self.create_random_sample())
+
+        # Check the array is now longer than the block length (successfully extended)
+        self.assertEqual(len(chain._chain_array), 4 * block_length)
+
+    def test_get_item(self):
+        chain = self.create_chain()
+        tenth_sample = chain[10]
+        self.assertTrue(isinstance(tenth_sample, Sample))
+
+        last_sample = chain[-1]
+        self.assertEqual(last_sample, chain.current_sample)
+
+        with self.assertRaises(SamplerError):
+            chain[chain.position + 10]
+
+    def test_set_item(self):
+        chain = self.create_chain()
+        s = self.create_random_sample()
+
+        chain[10] = s
+        self.assertEqual(s, chain[10])
+
+        chain[-1] = s
+        self.assertEqual(s, chain[-1])
+
+    def test_random_sample(self):
+        chain = self.create_chain()
+        c1 = chain.random_sample
+        c2 = chain.random_sample
+        self.assertNotEqual(c1, c2)
+
+    def test_fixed_discard(self):
+        chain = self.create_chain()
+        self.assertEqual(chain.fixed_discard, 0)
+        chain.fixed_discard = 10
+        self.assertEqual(chain.fixed_discard, 10)
+
+    def test_minimum_index(self):
+        chain = self.create_chain()
+        self.assertEqual(chain.minimum_index, 0)
+
+        chain._last_minimum_index = (chain.position, 10, "I")
+        self.assertEqual(chain.minimum_index, 10)
+        chain._last_minimum_index = (0, 0, "I")
+
+        chain.fixed_discard = 200
+        self.assertEqual(chain.minimum_index, 200)
+        chain._last_minimum_index = (0, 0, "I")
+
+        chain.fixed_discard = 100000
+        self.assertEqual(chain.minimum_index, 100000)
+
+    def test_tau(self):
+        chain = self.create_chain(n=1000)
+        self.assertGreaterEqual(chain.tau, chain.min_tau)
+        self.assertLess(chain.tau, np.inf)
+        self.assertEqual(chain.tau, chain.max_tau_dict[chain.position])
+        self.assertEqual(chain.tau, chain.tau_last)
+
+        # Check the cached tau calc works
+        for i in range(5):
+            chain.append(self.create_random_sample())
+        chain.tau
+        self.assertEqual(chain.cached_tau_count, 1)
+
+    def test_nsamples(self):
+        chain = self.create_chain(n=1000)
+        self.assertGreaterEqual(chain.nsamples, 1)
+        self.assertLessEqual(chain.nsamples, chain.position)
+
+    def test_thin(self):
+        chain = self.create_chain(n=1000)
+        self.assertEqual(chain.thin, int(chain.thin_by_nact * chain.tau))
+
+    def test_samples(self):
+        chain = self.create_chain(n=1000)
+        samples = chain.samples
+        self.assertTrue(isinstance(samples, pd.DataFrame))
+        self.assertTrue("a" in samples)
+        self.assertTrue("b" in samples)
+        self.assertTrue(LOGLKEY in samples)
+        self.assertTrue(LOGPKEY in samples)
+
+    def test_plot(self):
+        chain = self.create_chain(n=1000)
+        chain.plot(outdir=self.outdir, label="test")
+        self.assertTrue(os.path.exists(f"{self.outdir}/test_checkpoint_trace.png"))
+        priors = dict(
+            a=bilby.core.prior.Uniform(-10, 10, latex_label='a'),
+            b=bilby.core.prior.Uniform(-10, 10),
+        )
+        chain.thin_by_nact = 0.5
+        chain.plot(outdir=self.outdir, label="test", priors=priors)
+        self.assertTrue(os.path.exists(f"{self.outdir}/test_checkpoint_trace.png"))
+
+
+class TestSample(unittest.TestCase):
+    def setUp(self):
+        self.sample_dict = dict(a=1, b=2)
+
+    def tearDown(self):
+        del self.sample_dict
+
+    def test_init(self):
+        s = Sample(self.sample_dict)
+        self.assertEqual(s.keys, list(self.sample_dict.keys()))
+
+    def test_dict_access(self):
+        s = Sample(self.sample_dict)
+        for key in s.keys:
+            self.assertEqual(s[key], self.sample_dict[key])
+
+    def test_list_access(self):
+        s = Sample(self.sample_dict)
+        slist = s.list
+        self.assertEqual(slist, [self.sample_dict['a'], self.sample_dict['b']])
+
+    def test_setitem(self):
+        s = Sample(self.sample_dict)
+
+        # Set existing parameter
+        s['a'] = 100
+        self.assertEqual(s['a'], 100)
+
+        # Add parameter
+        s['c'] = 100
+        self.assertEqual(s['c'], 100)
+
+    def test_parameter_only_dict(self):
+        s = Sample(self.sample_dict)
+        self.assertEqual(s.parameter_only_dict, dict(a=1, b=2))
+
+    def test_update(self):
+        sample_dict = dict(a=1, b=2)
+        curr = Sample(sample_dict)
+        prop = curr.copy()
+        prop['a'] = 200
+        self.assertEqual(prop['a'], 200)
+        self.assertEqual(curr['a'], 1)
+
+
+class TestACT(unittest.TestCase):
+    def test_act_normal(self):
+        x = np.random.normal(0, 1, 1000)
+        tau = calculate_tau(x)
+        self.assertLess(tau, 10)
+
+    def test_act_identical(self):
+        x = np.array([0] * 1000)
+        tau = calculate_tau(x)
+        self.assertEqual(tau, np.inf)
+
+    def test_act_long(self):
+        t = np.linspace(0, 1, 1000)
+        x = np.sin(2 * np.pi * t)
+        tau = calculate_tau(x)
+        self.assertGreater(tau, 10)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/bilby_mcmc/proposals.py b/test/bilby_mcmc/proposals.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cf71f213e01ddaac2ec90b989e7a6f666d86ae7
--- /dev/null
+++ b/test/bilby_mcmc/proposals.py
@@ -0,0 +1,187 @@
+import os
+import copy
+import shutil
+import unittest
+import inspect
+import sys
+import time
+import bilby
+from bilby.bilby_mcmc.chain import Chain, Sample
+from bilby.bilby_mcmc import proposals
+from bilby.bilby_mcmc.utils import LOGLKEY, LOGPKEY
+import numpy as np
+
+
+class GivenProposal(proposals.BaseProposal):
+    """ A simple proposal class used for testing """
+    def __init__(self, priors, weight=1, subset=None, sigma=0.01):
+        super(GivenProposal, self).__init__(priors, weight, subset)
+
+    def propose(self, chain):
+        log_factor = 0
+        return self.given_sample, log_factor
+
+
+class TestBaseProposals(unittest.TestCase):
+    def create_priors(self, ndim=2, boundary=None):
+        priors = bilby.core.prior.PriorDict({
+            f'x{i}': bilby.core.prior.Uniform(-10, 10, name=f'x{i}', boundary=boundary)
+            for i in range(ndim)
+        })
+        priors["fixedA"] = bilby.core.prior.DeltaFunction(1)
+        return priors
+
+    def create_random_sample(self, ndim=2):
+        p = {f"x{i}": np.random.normal(0, 1) for i in range(ndim)}
+        p[LOGLKEY] = np.random.normal(0, 1)
+        p[LOGPKEY] = -1
+        p["fixedA"] = 1
+        return Sample(p)
+
+    def create_chain(self, n=1000, ndim=2):
+        initial_sample = self.create_random_sample(ndim)
+        chain = Chain(initial_sample=initial_sample)
+        for i in range(n):
+            chain.append(self.create_random_sample(ndim))
+        return chain
+
+    def test_GivenProposal(self):
+        priors = self.create_priors()
+        chain = self.create_chain()
+        proposal = GivenProposal(priors)
+        proposal.given_sample = self.create_random_sample()
+        prop, _ = proposal(chain)
+        self.assertEqual(prop, proposal.given_sample)
+
+    def test_noboundary(self):
+        priors = self.create_priors()
+        chain = self.create_chain()
+        proposal = GivenProposal(priors)
+
+        sample = self.create_random_sample()
+        sample["x0"] = priors["x0"].maximum + 0.5
+        proposal.given_sample = sample
+
+        prop, _ = proposal(chain)
+        self.assertEqual(prop, proposal.given_sample)
+        self.assertEqual(prop["x0"], priors["x0"].maximum + 0.5)
+
+    def test_periodic_boundary_above(self):
+        priors = self.create_priors(boundary="periodic")
+        chain = self.create_chain()
+        proposal = GivenProposal(priors)
+
+        sample = self.create_random_sample()
+        sample["x0"] = priors["x0"].maximum + 0.5
+        proposal.given_sample = copy.deepcopy(sample)
+
+        prop, _ = proposal(chain)
+        self.assertFalse(prop["x0"] == priors["x0"].maximum + 0.5)
+        self.assertEqual(prop["x0"], priors["x0"].minimum + 0.5)
+
+    def test_periodic_boundary_below(self):
+        priors = self.create_priors(boundary="periodic")
+        chain = self.create_chain()
+        proposal = GivenProposal(priors)
+
+        sample = self.create_random_sample()
+        sample["x0"] = priors["x0"].minimum - 0.5
+        proposal.given_sample = copy.deepcopy(sample)
+
+        prop, _ = proposal(chain)
+        self.assertFalse(prop["x0"] == priors["x0"].minimum - 0.5)
+        self.assertEqual(prop["x0"], priors["x0"].maximum - 0.5)
+
+
+class TestProposals(TestBaseProposals):
+    def setUp(self):
+        self.outdir = "chain_test"
+        if os.path.isdir(self.outdir) is False:
+            os.mkdir(self.outdir)
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
+
+    def get_simple_proposals(self):
+        clsmembers = inspect.getmembers(
+            sys.modules[proposals.__name__], inspect.isclass
+        )
+        clsmembers_clean = []
+        for name, cls in clsmembers:
+            if (
+                "Proposal" in name
+                and "Base" not in name
+                and "Ensemble" not in name
+                and "Phase" not in name
+                and "Polarisation" not in name
+                and "Cycle" not in name
+                and "KDE" not in name
+                and "NormalizingFlow" not in name
+            ):
+                clsmembers_clean.append((name, cls))
+
+        return clsmembers_clean
+
+    def proposal_check(self, prop, ndim=2, N=100):
+        chain = self.create_chain(ndim=ndim)
+
+        print(f"Testing {prop.__class__.__name__}")
+        # Timing and return type
+        start = time.time()
+        for _ in range(N):
+            p, w = prop(chain)
+            chain.append(p)
+        dt = 1e3 * (time.time() - start) / N
+        print(f"Testing {prop.__class__.__name__}: dt~{dt:0.2g} [ms]")
+
+        self.assertTrue(isinstance(p, Sample))
+        self.assertTrue(isinstance(w, (int, float)))
+
+    def test_proposal_return_type(self):
+        priors = self.create_priors()
+        for name, cls in self.get_simple_proposals():
+            prop = cls(priors)
+            self.proposal_check(prop)
+
+    def test_KDE_proposal(self):
+        priors = self.create_priors()
+        prop = proposals.KDEProposal(priors)
+        self.proposal_check(prop, N=20000)
+        self.assertTrue(prop.trained)
+
+    def test_GMM_proposal(self):
+        priors = self.create_priors()
+        prop = proposals.GMMProposal(priors)
+        self.proposal_check(prop, N=20000)
+        self.assertTrue(prop.trained)
+
+    def test_NF_proposal(self):
+        priors = self.create_priors()
+        chain = self.create_chain(10000)
+        prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
+        prop.steps_since_refit = 9999
+        start = time.time()
+        p, w = prop(chain)
+        dt = time.time() - start
+        print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
+        self.assertTrue(prop.trained)
+
+        self.proposal_check(prop)
+
+    def test_NF_proposal_15D(self):
+        ndim = 15
+        priors = self.create_priors(ndim)
+        chain = self.create_chain(10000, ndim=ndim)
+        prop = proposals.NormalizingFlowProposal(priors, first_fit=10000)
+        prop.steps_since_refit = 9999
+        start = time.time()
+        p, w = prop(chain)
+        dt = time.time() - start
+        print(f"Training for {prop.__class__.__name__} took dt~{dt:0.2g} [s]")
+        self.assertTrue(prop.trained)
+
+        self.proposal_check(prop, ndim=ndim)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/bilby_mcmc/sampler.py b/test/bilby_mcmc/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b5ff98b7f40445a07d4d3aeae682cf662bc95d
--- /dev/null
+++ b/test/bilby_mcmc/sampler.py
@@ -0,0 +1,86 @@
+import os
+import shutil
+import unittest
+
+import bilby
+from bilby.bilby_mcmc.sampler import (
+    Bilby_MCMC,
+    BilbyMCMCSampler,
+    _initialize_global_variables,
+)
+from bilby.bilby_mcmc.utils import ConvergenceInputs
+from bilby.core.sampler.base_sampler import SamplerError
+import numpy as np
+import pandas as pd
+
+
+class TestBilbyMCMCSampler(unittest.TestCase):
+    def setUp(self):
+        default_kwargs = Bilby_MCMC.default_kwargs
+        default_kwargs["target_nsamples"] = 100
+        default_kwargs["L1steps"] = 1
+        self.convergence_inputs = ConvergenceInputs(
+            **{key: default_kwargs[key] for key in ConvergenceInputs._fields}
+        )
+
+        self.outdir = "bilby_mcmc_sampler_test"
+        if os.path.isdir(self.outdir) is False:
+            os.mkdir(self.outdir)
+
+        def model(time, m, c):
+            return time * m + c
+        injection_parameters = dict(m=0.5, c=0.2)
+        sampling_frequency = 10
+        time_duration = 10
+        time = np.arange(0, time_duration, 1 / sampling_frequency)
+        N = len(time)
+        sigma = np.random.normal(1, 0.01, N)
+        data = model(time, **injection_parameters) + np.random.normal(0, sigma, N)
+        likelihood = bilby.likelihood.GaussianLikelihood(time, data, model, sigma)
+
+        # From hereon, the syntax is exactly equivalent to other bilby examples
+        # We make a prior
+        priors = dict()
+        priors['m'] = bilby.core.prior.Uniform(0, 5, 'm')
+        priors['c'] = bilby.core.prior.Uniform(-2, 2, 'c')
+        priors = bilby.core.prior.PriorDict(priors)
+
+        search_parameter_keys = ['m', 'c']
+        use_ratio = False
+
+        _initialize_global_variables(likelihood, priors, search_parameter_keys, use_ratio)
+
+    def tearDown(self):
+        if os.path.isdir(self.outdir):
+            shutil.rmtree(self.outdir)
+
+    def test_None_proposal_cycle(self):
+        with self.assertRaises(SamplerError):
+            BilbyMCMCSampler(
+                convergence_inputs=self.convergence_inputs,
+                proposal_cycle=None,
+                beta=1,
+                Tindex=0,
+                Eindex=0,
+                use_ratio=False
+            )
+
+    def test_default_proposal_cycle(self):
+        sampler = BilbyMCMCSampler(
+            convergence_inputs=self.convergence_inputs,
+            proposal_cycle="default_noNFnoGMnoKD",
+            beta=1,
+            Tindex=0,
+            Eindex=0,
+            use_ratio=False
+        )
+
+        nsteps = 0
+        while sampler.nsamples < 500:
+            sampler.step()
+            nsteps += 1
+        self.assertEqual(sampler.chain.position, nsteps)
+        self.assertEqual(sampler.accepted + sampler.rejected, nsteps)
+        self.assertTrue(isinstance(sampler.samples, pd.DataFrame))
+        for prop in sampler.proposal_cycle.proposal_list:
+            self.assertGreater(prop.n, 50)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py
index 46f09bbb6988e839d8c4390f278b470b1d882fab..2b225c01bdf318ff0eadfd85597aeed95d2a9f4c 100644
--- a/test/integration/sampler_run_test.py
+++ b/test/integration/sampler_run_test.py
@@ -23,7 +23,7 @@ class TestRunningSamplers(unittest.TestCase):
         )
 
         self.priors = bilby.core.prior.PriorDict()
-        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="reflective")
+        self.priors["m"] = bilby.core.prior.Uniform(0, 5, boundary="periodic")
         self.priors["c"] = bilby.core.prior.Uniform(-2, 2, boundary="reflective")
         bilby.core.utils.check_directory_exists_and_if_not_mkdir("outdir")
 
@@ -189,6 +189,13 @@ class TestRunningSamplers(unittest.TestCase):
             sampler='ultranest', save=False,
         )
 
+    def test_run_bilby_mcmc(self):
+        _ = bilby.run_sampler(
+            likelihood=self.likelihood, priors=self.priors,
+            sampler="bilby_mcmc", nsamples=200, save=False,
+            printdt=1,
+        )
+
 
 if __name__ == "__main__":
     unittest.main()