From 1b7c1826b16c45646abfdea2d1bdcaab34e5fe0e Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Thu, 13 Apr 2023 15:45:06 +0000
Subject: [PATCH] bilby-mcmc updates

---
 bilby/bilby_mcmc/chain.py                     |   9 +-
 bilby/bilby_mcmc/proposals.py                 | 144 +++++++++++++--
 bilby/bilby_mcmc/sampler.py                   |  35 +++-
 bilby/core/__init__.py                        |   2 +-
 bilby/core/fisher.py                          | 173 ++++++++++++++++++
 bilby/core/sampler/base_sampler.py            |   1 +
 bilby/gw/prior.py                             |  32 +++-
 bilby/gw/source.py                            |  11 ++
 .../linear_regression_with_Fisher.py          |  71 +++++++
 test/bilby_mcmc/test_proposals.py             |   2 +
 10 files changed, 446 insertions(+), 34 deletions(-)
 create mode 100644 bilby/core/fisher.py
 create mode 100644 examples/core_examples/linear_regression_with_Fisher.py

diff --git a/bilby/bilby_mcmc/chain.py b/bilby/bilby_mcmc/chain.py
index 66d6c97f7..9aec43421 100644
--- a/bilby/bilby_mcmc/chain.py
+++ b/bilby/bilby_mcmc/chain.py
@@ -157,7 +157,7 @@ class Chain(object):
 
     @property
     def minimum_index(self):
-        """This calculated a minimum index from which to discard samples
+        """This calculates a minimum index from which to discard samples
 
         A number of methods are provided for the calculation. A subset are
         switched off (by `if False` statements) for future development
@@ -342,7 +342,12 @@ class Chain(object):
     @property
     def nsamples(self):
         nuseable_steps = self.position - self.minimum_index
-        return int(nuseable_steps / (self.thin_by_nact * self.tau))
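+        # Steps in the chain are correlated: divide by the autocorrelation
+        # time tau to count independent samples, then apply the thinning factor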
+        n_independent_samples = nuseable_steps / self.tau
+        nsamples = int(n_independent_samples / self.thin_by_nact)
+        if nuseable_steps >= nsamples:
+            return nsamples
+        else:
+            return 0
 
     @property
     def nsamples_last(self):
diff --git a/bilby/bilby_mcmc/proposals.py b/bilby/bilby_mcmc/proposals.py
index 90e2f567d..74ce59097 100644
--- a/bilby/bilby_mcmc/proposals.py
+++ b/bilby/bilby_mcmc/proposals.py
@@ -6,6 +6,7 @@ import numpy as np
 from scipy.spatial.distance import jensenshannon
 from scipy.stats import gaussian_kde
 
+from ..core.fisher import FisherMatrixPosteriorEstimator
 from ..core.prior import PriorDict
 from ..core.sampler.base_sampler import SamplerError
 from ..core.utils import logger, reflect
@@ -61,6 +62,9 @@ class BaseProposal(object):
             self.parameters = [p for p in self.parameters if p in subset]
             self._str_attrs.append("parameters")
 
+        if len(self.parameters) == 0:
+            raise ValueError("Proposal requested with zero parameters")
+
         self.ndim = len(self.parameters)
 
         self.prior_boundary_dict = {key: priors[key].boundary for key in priors}
@@ -129,10 +133,16 @@ class BaseProposal(object):
         val_normalised_reflected = reflect(np.array(val_normalised))
         return minimum + width * val_normalised_reflected
 
-    def __call__(self, chain):
-        sample, log_factor = self.propose(chain)
+    def __call__(self, chain, likelihood=None, priors=None):
+
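+        # Proposals that need the likelihood and priors (e.g. FisherMatrixProposal)
+        # set ``needs_likelihood_and_priors = True`` and receive them here; all
+        # other proposals keep the chain-only signature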
+        if getattr(self, "needs_likelihood_and_priors", False):
+            sample, log_factor = self.propose(chain, likelihood, priors)
+        else:
+            sample, log_factor = self.propose(chain)
+
         if log_factor == 0:
             sample = self.apply_boundaries(sample)
+
         return sample, log_factor
 
     @abstractmethod
@@ -459,7 +469,7 @@ class DensityEstimateProposal(BaseProposal):
 
         # Print a log message
         took = time.time() - start
-        logger.info(
+        logger.debug(
             f"{self.density_name} construction at {self.steps_since_refit} finished"
             f" for length {chain.position} chain, took {took:0.2f}s."
             f" Current accept-ratio={self.acceptance_ratio:0.2f}"
@@ -480,7 +490,7 @@ class DensityEstimateProposal(BaseProposal):
                 fail_parameters.append(key)
 
         if len(fail_parameters) > 0:
-            logger.info(
+            logger.debug(
                 f"{self.density_name} construction failed verification and is discarded"
             )
             self.density = current_density
@@ -493,7 +503,10 @@ class DensityEstimateProposal(BaseProposal):
         # Check if we refit
         testA = self.steps_since_refit >= self.next_refit_time
         if testA:
-            self.refit(chain)
+            try:
+                self.refit(chain)
+            except Exception as e:
+                logger.warning(f"Failed to refit chain due to error {e}")
 
         # If KDE is yet to be fitted, use the fallback
         if self.trained is False:
@@ -656,7 +669,7 @@ class NormalizingFlowProposal(DensityEstimateProposal):
         return np.power(max_js, 2)
 
     def train(self, chain):
-        logger.info("Starting NF training")
+        logger.debug("Starting NF training")
 
         import torch
 
@@ -687,14 +700,14 @@ class NormalizingFlowProposal(DensityEstimateProposal):
                     validation_samples, training_samples_draw
                 )
                 if max_js_bits < max_js_threshold:
-                    logger.info(
+                    logger.debug(
                         f"Training complete after {epoch} steps, "
                         f"max_js_bits={max_js_bits:0.5f}<{max_js_threshold}"
                     )
                     break
 
         took = time.time() - start
-        logger.info(
+        logger.debug(
             f"Flow training step ({self.steps_since_refit}) finished"
             f" for length {chain.position} chain, took {took:0.2f}s."
             f" Current accept-ratio={self.acceptance_ratio:0.2f}"
@@ -715,7 +728,10 @@ class NormalizingFlowProposal(DensityEstimateProposal):
         # Check if we retrain the NF
         testA = self.steps_since_refit >= self.next_refit_time
         if testA:
-            self.train(chain)
+            try:
+                self.train(chain)
+            except Exception as e:
+                logger.warning(f"Failed to retrain chain due to error {e}")
 
         if self.trained is False:
             return self.fallback.propose(chain)
@@ -772,6 +788,64 @@ class FixedJumpProposal(BaseProposal):
         return self.scale * np.random.normal()
 
 
+class FisherMatrixProposal(AdaptiveGaussianProposal):
+    """Fisher Matrix proposal
+
+    Uses a finite-differencing approach motivated by BayesWave (see, e.g.,
+    https://arxiv.org/abs/1410.3835). The inverse Fisher Information Matrix
+    is calculated from the current sample, then proposals are drawn from a
+    multivariate Gaussian and scaled by an adaptive parameter.
+    """
+
+    needs_likelihood_and_priors = True
+
+    def __init__(
+        self,
+        priors,
+        subset=None,
+        weight=1,
+        update_interval=100,
+        scale_init=1e0,
+        fd_eps=1e-6,
+        adapt=False,
+    ):
+        super(FisherMatrixProposal, self).__init__(
+            priors, weight, subset, scale_init=scale_init
+        )
+        self.update_interval = update_interval
+        self.steps_since_update = update_interval
+        self.adapt = adapt
+        self.mean = np.zeros(len(self.parameters))
+        self.fd_eps = fd_eps
+
+    def propose(self, chain, likelihood, priors):
+        sample = chain.current_sample
+        if self.adapt:
+            self.update_scale(chain)
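+        # Recompute the inverse FIM only every ``update_interval`` steps;
+        # finite differencing is expensive, so the matrix is cached between updates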
+        if self.steps_since_update >= self.update_interval:
+            fmp = FisherMatrixPosteriorEstimator(
+                likelihood, priors, parameters=self.parameters, fd_eps=self.fd_eps
+            )
+            try:
+                self.iFIM = fmp.calculate_iFIM(sample.dict)
+            except (RuntimeError, ValueError) as e:
+                logger.warning(f"FisherMatrixProposal failed with {e}")
+                if not hasattr(self, "iFIM"):
+                    # No past iFIM exists; return the sample unchanged
+                    return sample, 0
+            self.steps_since_update = 0
+
+        jump = self.scale * np.random.multivariate_normal(
+            self.mean, self.iFIM, check_valid="ignore"
+        )
+
+        for key, val in zip(self.parameters, jump):
+            sample[key] += val
+
+        log_factor = 0
+        self.steps_since_update += 1
+        return sample, log_factor
+
+
 class BaseGravitationalWaveTransientProposal(BaseProposal):
     def __init__(self, priors, weight=1):
         super(BaseGravitationalWaveTransientProposal, self).__init__(
@@ -985,7 +1059,7 @@ def get_default_ensemble_proposal_cycle(priors):
 def get_proposal_cycle(string, priors, L1steps=1, warn=True):
     big_weight = 10
     small_weight = 5
-    tiny_weight = 0.1
+    tiny_weight = 0.5
 
     if "gwA" in string:
         # Parameters for learning proposals
@@ -1009,15 +1083,15 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
         if priors.intrinsic:
             intrinsic = PARAMETER_SETS["intrinsic"]
             plist += [
-                AdaptiveGaussianProposal(priors, weight=big_weight, subset=intrinsic),
+                AdaptiveGaussianProposal(priors, weight=small_weight, subset=intrinsic),
                 DifferentialEvolutionProposal(
-                    priors, weight=big_weight, subset=intrinsic
+                    priors, weight=small_weight, subset=intrinsic
                 ),
                 KDEProposal(
-                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                    priors, weight=small_weight, subset=intrinsic, **learning_kwargs
                 ),
                 GMMProposal(
-                    priors, weight=big_weight, subset=intrinsic, **learning_kwargs
+                    priors, weight=small_weight, subset=intrinsic, **learning_kwargs
                 ),
             ]
 
@@ -1026,13 +1100,13 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
             plist += [
                 AdaptiveGaussianProposal(priors, weight=small_weight, subset=extrinsic),
                 DifferentialEvolutionProposal(
-                    priors, weight=big_weight, subset=extrinsic
+                    priors, weight=small_weight, subset=extrinsic
                 ),
                 KDEProposal(
-                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                    priors, weight=small_weight, subset=extrinsic, **learning_kwargs
                 ),
                 GMMProposal(
-                    priors, weight=big_weight, subset=extrinsic, **learning_kwargs
+                    priors, weight=small_weight, subset=extrinsic, **learning_kwargs
                 ),
             ]
 
@@ -1043,6 +1117,11 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
                 GMMProposal(
                     priors, weight=small_weight, subset=mass, **learning_kwargs
                 ),
+                FisherMatrixProposal(
+                    priors,
+                    weight=small_weight,
+                    subset=mass,
+                ),
             ]
 
         if priors.spin:
@@ -1052,13 +1131,23 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
                 GMMProposal(
                     priors, weight=small_weight, subset=spin, **learning_kwargs
                 ),
+                FisherMatrixProposal(
+                    priors,
+                    weight=big_weight,
+                    subset=spin,
+                ),
             ]
-        if priors.precession:
-            measured_spin = ["chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane"]
+        if priors.measured_spin:
+            measured_spin = PARAMETER_SETS["measured_spin"]
             plist += [
                 AdaptiveGaussianProposal(
                     priors, weight=small_weight, subset=measured_spin
                 ),
+                FisherMatrixProposal(
+                    priors,
+                    weight=small_weight,
+                    subset=measured_spin,
+                ),
             ]
 
         if priors.mass and priors.spin:
@@ -1086,6 +1175,21 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
                 CorrelatedPolarisationPhaseJump(priors, weight=tiny_weight),
                 PhasePolarisationReversalProposal(priors, weight=tiny_weight),
             ]
+        if priors.sky:
+            sky = PARAMETER_SETS["sky"]
+            plist += [
+                FisherMatrixProposal(
+                    priors,
+                    weight=small_weight,
+                    subset=sky,
+                ),
+                GMMProposal(
+                    priors,
+                    weight=small_weight,
+                    subset=sky,
+                    **learning_kwargs,
+                ),
+            ]
         for key in ["time_jitter", "psi", "phi_12", "tilt_2", "lambda_1", "lambda_2"]:
             if key in priors.non_fixed_keys:
                 plist.append(PriorProposal(priors, subset=[key], weight=tiny_weight))
@@ -1101,6 +1205,7 @@ def get_proposal_cycle(string, priors, L1steps=1, warn=True):
             DifferentialEvolutionProposal(priors, weight=big_weight),
             UniformProposal(priors, weight=tiny_weight),
             KDEProposal(priors, weight=big_weight, scale_fits=L1steps),
+            FisherMatrixProposal(priors, weight=big_weight),
         ]
         if GMMProposal.check_dependencies(warn=warn):
             plist.append(GMMProposal(priors, weight=big_weight, scale_fits=L1steps))
@@ -1120,6 +1225,7 @@ def remove_proposals_using_string(plist, string):
         GM=GMMProposal,
         PR=PriorProposal,
         UN=UniformProposal,
+        FM=FisherMatrixProposal,
     )
 
     for element in string.split("no")[1:]:
diff --git a/bilby/bilby_mcmc/sampler.py b/bilby/bilby_mcmc/sampler.py
index 45f9dcf84..efa2fcb85 100644
--- a/bilby/bilby_mcmc/sampler.py
+++ b/bilby/bilby_mcmc/sampler.py
@@ -146,6 +146,7 @@ class Bilby_MCMC(MCMCSampler):
         L1steps=100,
         L2steps=3,
         printdt=60,
+        check_point_delta_t=1800,
         min_tau=1,
         proposal_cycle="default",
         stop_after_convergence=False,
@@ -165,7 +166,6 @@ class Bilby_MCMC(MCMCSampler):
         use_ratio=False,
         skip_import_verification=True,
         check_point_plot=True,
-        check_point_delta_t=1800,
         diagnostic=False,
         resume=True,
         exit_code=130,
@@ -202,9 +202,9 @@ class Bilby_MCMC(MCMCSampler):
         self.initial_sample_dict = self.kwargs["initial_sample_dict"]
 
         self.printdt = self.kwargs["printdt"]
+        self.check_point_delta_t = self.kwargs["check_point_delta_t"]
         check_directory_exists_and_if_not_mkdir(self.outdir)
         self.resume = resume
-        self.check_point_delta_t = check_point_delta_t
         self.resume_file = "{}/{}_resume.pickle".format(self.outdir, self.label)
 
         self.verify_configuration()
@@ -224,6 +224,10 @@ class Bilby_MCMC(MCMCSampler):
             for equiv in self.npool_equiv_kwargs:
                 if equiv in kwargs:
                     kwargs["npool"] = kwargs.pop(equiv)
+        if "check_point_delta_t" not in kwargs:
+            for equiv in self.check_point_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs["check_point_delta_t"] = kwargs.pop(equiv)
 
     @property
     def target_nsamples(self):
@@ -315,8 +319,8 @@ class Bilby_MCMC(MCMCSampler):
         self._steps_since_last_print = 0
         self._time_since_last_print = 0
         logger.info(f"Drawing {self.target_nsamples} samples")
-        logger.info(f"Checkpoint every {self.check_point_delta_t}s")
-        logger.info(f"Print update every {self.printdt}s")
+        logger.info(f"Checkpoint every check_point_delta_t={self.check_point_delta_t}s")
+        logger.info(f"Print update every printdt={self.printdt}s")
 
         while True:
             t0 = datetime.datetime.now()
@@ -390,7 +394,6 @@ class Bilby_MCMC(MCMCSampler):
                 )
                 raise ResumeError(msg)
             self.ptsampler.set_convergence_inputs(self.convergence_inputs)
-            self.ptsampler.proposal_cycle = self.proposal_cycle
             self.ptsampler.pt_rejection_sample = self.pt_rejection_sample
 
         logger.info(
@@ -734,13 +737,19 @@ class BilbyPTMCMCSampler(object):
 
     @property
     def samples(self):
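+        # Cache the concatenated samples keyed by the current position so that
+        # repeated calls between steps avoid rebuilding the DataFrame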
+        cached_samples = getattr(self, "_cached_samples", (False,))
+        if cached_samples[0] == self.position:
+            return cached_samples[1]
+
         sample_list = []
         for sampler in self.zerotemp_sampler_list:
             sample_list.append(sampler.samples)
         if self.pt_rejection_sample:
             for sampler in self.tempered_sampler_list:
                 sample_list.append(sampler.samples)
-        return pd.concat(sample_list)
+        samples = pd.concat(sample_list, ignore_index=True)
+        self._cached_samples = (self.position, samples)
+        return samples
 
     @property
     def position(self):
@@ -783,7 +792,7 @@ class BilbyPTMCMCSampler(object):
 
     @staticmethod
     def _get_sample_to_swap(sampler):
-        if sampler.chain.converged is False:
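+        # Keep swapping the latest sample unless the chain has converged and
+        # stop_after_convergence is set; only then draw a random past sample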
+        if not (sampler.chain.converged and sampler.stop_after_convergence):
             v = sampler.chain[-1]
         else:
             v = sampler.chain.random_sample
@@ -897,7 +906,7 @@ class BilbyPTMCMCSampler(object):
             ln_z, ln_z_err = self.compute_evidence_per_ensemble(method, kwargs)
             self.ln_z_dict[key] = ln_z
             self.ln_z_err_dict[key] = ln_z_err
-            logger.info(
+            logger.debug(
                 f"Log-evidence of {ln_z:0.2f}+/-{ln_z_err:0.2f} calculated using {key} method"
             )
 
@@ -1141,7 +1150,9 @@ class BilbyMCMCSampler(object):
         if initial_sample_dict is not None:
             initial_sample.update(initial_sample_dict)
 
-        logger.info(f"Using initial sample {initial_sample}")
+        if self.beta == 1:
+            logger.info(f"Using initial sample {initial_sample}")
+
         initial_sample = Sample(initial_sample)
         initial_sample[LOGLKEY] = self.log_likelihood(initial_sample)
         initial_sample[LOGPKEY] = self.log_prior(initial_sample)
@@ -1216,7 +1227,11 @@ class BilbyMCMCSampler(object):
         while internal_steps < self.chain.L1steps:
             internal_steps += 1
             proposal = self.proposal_cycle.get_proposal()
-            prop, log_factor = proposal(self.chain)
+            prop, log_factor = proposal(
+                self.chain,
+                likelihood=_sampling_convenience_dump.likelihood,
+                priors=_sampling_convenience_dump.priors,
+            )
             logp = self.log_prior(prop)
 
             if np.isinf(logp) or np.isnan(logp):
diff --git a/bilby/core/__init__.py b/bilby/core/__init__.py
index 7446bd24f..968f961d0 100644
--- a/bilby/core/__init__.py
+++ b/bilby/core/__init__.py
@@ -1 +1 @@
-from . import grid, likelihood, prior, result, sampler, series, utils
+from . import grid, likelihood, prior, result, sampler, series, utils, fisher
diff --git a/bilby/core/fisher.py b/bilby/core/fisher.py
new file mode 100644
index 000000000..fd452f887
--- /dev/null
+++ b/bilby/core/fisher.py
@@ -0,0 +1,173 @@
+import numpy as np
+import scipy.linalg
+
+import pandas as pd
+from scipy.optimize import minimize
+
+
+class FisherMatrixPosteriorEstimator(object):
+    def __init__(self, likelihood, priors, parameters=None, fd_eps=1e-6, n_prior_samples=100):
+        """ A class to estimate posteriors using the Fisher Matrix approach
+
+        Parameters
+        ----------
+        likelihood: bilby.core.likelihood.Likelihood
+            A bilby likelihood object
+        priors: bilby.core.prior.PriorDict
+            A bilby prior object
+        parameters: list
+            Names of parameters to sample in
+        fd_eps: float
+            A parameter controlling the size of the perturbation used when
+            finite differencing the likelihood
+        n_prior_samples: int
+            The number of prior samples to draw and use to attempt estimation
+            of the maximum likelihood sample.
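+
+        Examples
+        --------
+        A minimal sketch, assuming ``likelihood`` and ``priors`` are a bilby
+        likelihood and ``PriorDict`` (cf. the linear-regression example added
+        in this patch)::
+
+            fmp = FisherMatrixPosteriorEstimator(likelihood, priors)
+            samples = fmp.sample_dataframe("maxL", n=100)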
+        """
+        self.likelihood = likelihood
+        if parameters is None:
+            self.parameter_names = priors.non_fixed_keys
+        else:
+            self.parameter_names = parameters
+        self.fd_eps = fd_eps
+        self.n_prior_samples = n_prior_samples
+        self.N = len(self.parameter_names)
+
+        self.prior_samples = [
+            priors.sample_subset(self.parameter_names) for _ in range(n_prior_samples)
+        ]
+        self.prior_bounds = [(priors[key].minimum, priors[key].maximum) for key in self.parameter_names]
+
+        self.prior_width_dict = {}
+        for key in self.parameter_names:
+            width = priors[key].width
+            if np.isnan(width):
+                raise ValueError(f"Prior width is ill-formed for {key}")
+            self.prior_width_dict[key] = width
+
+    def log_likelihood(self, sample):
+        self.likelihood.parameters.update(sample)
+        return self.likelihood.log_likelihood()
+
+    def calculate_iFIM(self, sample):
+        FIM = self.calculate_FIM(sample)
+        iFIM = scipy.linalg.inv(FIM)
+
+        # Ensure iFIM is positive definite
+        min_eig = np.min(np.real(np.linalg.eigvals(iFIM)))
+        if min_eig < 0:
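+            # min_eig is negative here, so this shifts all eigenvalues up by
+            # 10 * |min_eig|, making iFIM positive definite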
+            iFIM -= 10 * min_eig * np.eye(*iFIM.shape)
+
+        return iFIM
+
+    def sample_array(self, sample, n=1):
+        if sample == "maxL":
+            sample = self.get_maximum_likelihood_sample()
+
+        self.mean = np.array(list(sample.values()))
+        self.iFIM = self.calculate_iFIM(sample)
+        return np.random.multivariate_normal(self.mean, self.iFIM, n)
+
+    def sample_dataframe(self, sample, n=1):
+        samples = self.sample_array(sample, n)
+        return pd.DataFrame(samples, columns=self.parameter_names)
+
+    def calculate_FIM(self, sample):
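+        # Fisher information approximated as the negative Hessian of the
+        # log-likelihood at the sample: FIM_ij = -d^2 lnL / (dtheta_i dtheta_j)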
+        FIM = np.zeros((self.N, self.N))
+        for ii, ii_key in enumerate(self.parameter_names):
+            for jj, jj_key in enumerate(self.parameter_names):
+                FIM[ii, jj] = -self.get_second_order_derivative(sample, ii_key, jj_key)
+
+        return FIM
+
+    def get_second_order_derivative(self, sample, ii, jj):
+        if ii == jj:
+            return self.get_finite_difference_xx(sample, ii)
+        else:
+            return self.get_finite_difference_xy(sample, ii, jj)
+
+    def get_finite_difference_xx(self, sample, ii):
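+        # Central second difference: d2f/dx2 ~ (f(x+h) - 2 f(x) + f(x-h)) / h^2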
+        # Sample grid
+        p = self.shift_sample_x(sample, ii, 1)
+        m = self.shift_sample_x(sample, ii, -1)
+
+        dx = .5 * (p[ii] - m[ii])
+
+        loglp = self.log_likelihood(p)
+        logl = self.log_likelihood(sample)
+        loglm = self.log_likelihood(m)
+
+        return (loglp - 2 * logl + loglm) / dx ** 2
+
+    def get_finite_difference_xy(self, sample, ii, jj):
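+        # Mixed central difference:
+        # d2f/dxdy ~ (f(+,+) - f(+,-) - f(-,+) + f(-,-)) / (4 dx dy)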
+        # Sample grid
+        pp = self.shift_sample_xy(sample, ii, 1, jj, 1)
+        pm = self.shift_sample_xy(sample, ii, 1, jj, -1)
+        mp = self.shift_sample_xy(sample, ii, -1, jj, 1)
+        mm = self.shift_sample_xy(sample, ii, -1, jj, -1)
+
+        dx = .5 * (pp[ii] - mm[ii])
+        dy = .5 * (pp[jj] - mm[jj])
+
+        loglpp = self.log_likelihood(pp)
+        loglpm = self.log_likelihood(pm)
+        loglmp = self.log_likelihood(mp)
+        loglmm = self.log_likelihood(mm)
+
+        return (loglpp - loglpm - loglmp + loglmm) / (4 * dx * dy)
+
+    def shift_sample_x(self, sample, x_key, x_coef):
+
+        vx = sample[x_key]
+        dvx = self.fd_eps * self.prior_width_dict[x_key]
+
+        shift_sample = sample.copy()
+        shift_sample[x_key] = vx + x_coef * dvx
+
+        return shift_sample
+
+    def shift_sample_xy(self, sample, x_key, x_coef, y_key, y_coef):
+
+        vx = sample[x_key]
+        vy = sample[y_key]
+
+        dvx = self.fd_eps * self.prior_width_dict[x_key]
+        dvy = self.fd_eps * self.prior_width_dict[y_key]
+
+        shift_sample = sample.copy()
+        shift_sample[x_key] = vx + x_coef * dvx
+        shift_sample[y_key] = vy + y_coef * dvy
+        return shift_sample
+
+    def get_maximum_likelihood_sample(self, initial_sample=None):
+        """ A method to attempt optimization of the maximum likelihood
+
+        This uses a simple scipy optimization approach, starting from a number
+        of draws from the prior to avoid problems with local optimization.
+
+        Note: this approach works well in low-dimensional problems when the
+        posterior is narrow relative to the prior. If the number of dimensions
+        is large, or the posterior is wide relative to the prior, the method is
+        likely to miss the global maximum.
+        """
+        minlogL = np.inf
+        for i in range(self.n_prior_samples):
+            initial_sample = self.prior_samples[i]
+
+            x0 = list(initial_sample.values())
+
+            def neg_log_like(x, estimator, T=1):
+                sample = {key: val for key, val in zip(estimator.parameter_names, x)}
+                return -1 / T * estimator.log_likelihood(sample)
+
+            out = minimize(
+                neg_log_like,
+                x0,
+                args=(self, 1),
+                bounds=self.prior_bounds,
+                method="L-BFGS-B",
+            )
+            if out.fun < minlogL:
+                minlogL = out.fun
+                minout = out
+
+        return {key: val for key, val in zip(self.parameter_names, minout.x)}
diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index e7ad043b3..32bde7c78 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -202,6 +202,7 @@ class Sampler(object):
     If a specific sampler does not have a sampling seed option, then it should be
     left as None.
     """
+    check_point_equiv_kwargs = ["check_point_deltaT", "check_point_delta_t"]
 
     def __init__(
         self,
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 28a0e0230..455296e55 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -752,8 +752,21 @@ class CBCPriorDict(ConditionalPriorDict):
             return None
 
     def is_nonempty_intersection(self, pset):
-        """ Check if keys in self exist in the PARAMETER_SETS pset """
-        if len(PARAMETER_SETS[pset].intersection(self.non_fixed_keys)) > 0:
+        """ Check if keys in self exist in the parameter set
+
+        Parameters
+        ----------
+        pset: str, set
+            Either a string referencing a parameter set in PARAMETER_SETS or
+            a set of keys
+        """
+        if isinstance(pset, str) and pset in PARAMETER_SETS:
+            check_set = PARAMETER_SETS[pset]
+        elif isinstance(pset, set):
+            check_set = pset
+        else:
+            raise ValueError(f"pset {pset} not understood")
+        if len(check_set.intersection(self.non_fixed_keys)) > 0:
             return True
         else:
             return False
@@ -768,6 +781,11 @@ class CBCPriorDict(ConditionalPriorDict):
         """ Return true if priors include any precession parameters """
         return self.is_nonempty_intersection("precession_only")
 
+    @property
+    def measured_spin(self):
+        """ Return true if priors include any measured_spin parameters """
+        return self.is_nonempty_intersection("measured_spin")
+
     @property
     def intrinsic(self):
         """ Return true if priors include any intrinsic parameters """
@@ -778,6 +796,16 @@ class CBCPriorDict(ConditionalPriorDict):
         """ Return true if priors include any extrinsic parameters """
         return self.is_nonempty_intersection("extrinsic")
 
+    @property
+    def sky(self):
+        """ Return true if priors include any extrinsic parameters """
+        return self.is_nonempty_intersection("sky")
+
+    @property
+    def distance_inclination(self):
+        """ Return true if priors include any extrinsic parameters """
+        return self.is_nonempty_intersection("distance_inclination")
+
     @property
     def mass(self):
         """ Return true if priors include any mass parameters """
diff --git a/bilby/gw/source.py b/bilby/gw/source.py
index a3e5cff6a..aa0ad5bbb 100644
--- a/bilby/gw/source.py
+++ b/bilby/gw/source.py
@@ -1081,10 +1081,21 @@ extrinsic = {
     "cos_theta_jn", "geocent_time", "time_jitter", "ra", "dec",
     "H1_time", "L1_time", "V1_time",
 }
+sky = {
+    "azimuth", "zenith", "ra", "dec",
+}
+distance_inclination = {
+    "luminosity_distance", "redshift", "theta_jn", "cos_theta_jn",
+}
+measured_spin = {
+    "chi_1", "chi_2", "a_1", "a_2", "chi_1_in_plane"
+}
 
 PARAMETER_SETS = dict(
     spin=spin, mass=mass, phase=phase, extrinsic=extrinsic,
     tidal=tidal, primary_spin_and_q=primary_spin_and_q,
     intrinsic=spin.union(mass).union(phase).union(tidal),
     precession_only=precession_only,
+    sky=sky, distance_inclination=distance_inclination,
+    measured_spin=measured_spin,
 )
diff --git a/examples/core_examples/linear_regression_with_Fisher.py b/examples/core_examples/linear_regression_with_Fisher.py
new file mode 100644
index 000000000..d17ddb25c
--- /dev/null
+++ b/examples/core_examples/linear_regression_with_Fisher.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+"""
+An example of how to use bilby to perform parameter estimation for
+non-gravitational wave data. In this case, fitting a linear function to
+data with background Gaussian noise. We then compare the result to posteriors
+estimated using the Fisher Information Matrix approximation.
+
+"""
+import copy
+
+import bilby
+import numpy as np
+
+# A few simple setup steps
+outdir = "outdir"
+
+np.random.seed(123)
+
+
+# First, we define our "signal model", in this case a simple linear function
+def model(time, m, c):
+    return time * m + c
+
+
+# Now we define the injection parameters which we make simulated data with
+injection_parameters = dict(m=0.5, c=0.2)
+
+# For this example, we'll use standard Gaussian noise
+
+# These lines of code generate the fake data. Note the ** just unpacks the
+# contents of the injection_parameters when calling the model function.
+sampling_frequency = 10
+time_duration = 10
+time = np.arange(0, time_duration, 1 / sampling_frequency)
+N = len(time)
+sigma = np.random.normal(1, 0.01, N)
+data = model(time, **injection_parameters) + np.random.normal(0, sigma, N)
+
+# Now let's instantiate a version of our GaussianLikelihood, giving it
+# the time, data and signal model
+likelihood = bilby.likelihood.GaussianLikelihood(time, data, model, sigma)
+
+# From hereon, the syntax is exactly equivalent to other bilby examples
+# We make a prior
+priors = dict()
+priors["m"] = bilby.core.prior.Uniform(0, 5, "m")
+priors["c"] = bilby.core.prior.Uniform(-2, 2, "c")
+priors = bilby.core.prior.PriorDict(priors)
+
+# And run sampler
+result = bilby.run_sampler(
+    likelihood=likelihood,
+    priors=priors,
+    sampler="dynesty",
+    nlive=1000,
+    injection_parameters=injection_parameters,
+    outdir=outdir,
+    label="Nested Sampling",
+)
+
+# Finally plot a corner plot: all outputs are stored in outdir
+result.plot_corner()
+
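+# Estimate the posterior with the Fisher Matrix approximation, expanding about
+# the maximum-likelihood point found internally by the estimator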
+fim = bilby.core.fisher.FisherMatrixPosteriorEstimator(likelihood, priors)
+result_fim = copy.deepcopy(result)
+result_fim.posterior = fim.sample_dataframe("maxL", 10000)
+result_fim.label = "Fisher"
+
+bilby.core.result.plot_multiple(
+    [result, result_fim], parameters=injection_parameters, truth_color="k"
+)
diff --git a/test/bilby_mcmc/test_proposals.py b/test/bilby_mcmc/test_proposals.py
index 84042d2d4..37fa0a0fe 100644
--- a/test/bilby_mcmc/test_proposals.py
+++ b/test/bilby_mcmc/test_proposals.py
@@ -129,6 +129,8 @@ class TestProposals(TestBaseProposals):
 
     def proposal_check(self, prop, ndim=2, N=100):
         chain = self.create_chain(ndim=ndim)
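+        # Proposals that require the likelihood and priors cannot be exercised
+        # by this generic check, which provides neither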
+        if getattr(prop, "needs_likelihood_and_priors", False):
+            return
 
         print(f"Testing {prop.__class__.__name__}")
         # Timing and return type
-- 
GitLab