diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 9eac23a5cdcfd4010abf36faa45def3509c95c55..ae8be2a1194a2b7219349a71fcf6fb8cb15f461d 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -1,4 +1,5 @@ import os +import copy import numpy as np from scipy.interpolate import InterpolatedUnivariateSpline @@ -19,6 +20,70 @@ except ImportError: " not be able to use some of the prebuilt functions.") +class BilbyPriorConversionError(Exception): + pass + + +def convert_to_flat_in_component_mass_prior(result, fraction=0.25): + """ Converts samples flat in chirp-mass and mass-ratio to flat in component mass + + Parameters + ---------- + result: bilby.core.result.Result + The output result complete with priors and posteriors + fraction: float [0, 1] + The fraction of samples to draw (default=0.25). Note, if too high a + fraction of samples are draw, the reweighting will not be applied in + effect. + + """ + if getattr(result, "priors") is not None: + if isinstance(getattr(result.priors, "chirp_mass", None), Uniform) is False: + BilbyPriorConversionError("chirp mass prior should be Uniform") + if isinstance(getattr(result.priors, "mass_ratio", None), Uniform) is False: + BilbyPriorConversionError("mass_ratio prior should be Uniform") + if isinstance(getattr(result.priors, "mass_1", None), Constraint): + BilbyPriorConversionError("mass_1 prior should be a Constraint") + if isinstance(getattr(result.priors, "mass_2", None), Constraint): + BilbyPriorConversionError("mass_2 prior should be a Constraint") + else: + BilbyPriorConversionError("No prior in the result: unable to convert") + + result = copy.copy(result) + priors = result.priors + posterior = result.posterior + + priors["chirp_mass"] = Constraint( + priors["chirp_mass"].minimum, priors["chirp_mass"].maximum, + "chirp_mass", latex_label=priors["chirp_mass"].latex_label) + priors["mass_ratio"] = Constraint( + priors["mass_ratio"].minimum, priors["mass_ratio"].maximum, + "mass_ratio", latex_label=priors["mass_ratio"].latex_label) + priors["mass_1"] = Uniform( + priors["mass_1"].minimum, priors["mass_1"].maximum, "mass_1", + latex_label=priors["mass_1"].latex_label) + priors["mass_1"] = Uniform( + priors["mass_2"].minimum, priors["mass_2"].maximum, "mass_2", + latex_label=priors["mass_2"].latex_label) + + weights = posterior["mass_1"] ** 2 / posterior["chirp_mass"] + result.posterior = posterior.sample(frac=fraction, weights=weights) + + logger.info("Resampling posterior to flat-in-component mass") + effective_sample_size = sum(weights)**2 / sum(weights**2) + n_posterior = len(posterior) + if fraction > effective_sample_size / n_posterior: + logger.warning( + "Sampling posterior of length {} with fraction {}, but " + "effective_sample_size / len(posterior) = {}. This may produce " + "biased results" + .format(n_posterior, fraction, effective_sample_size / n_posterior) + ) + result.posterior = posterior.sample(frac=fraction, weights=weights, replace=True) + result.meta_data["reweighted_to_flat_in_component_mass"] = True + return result + + class Cosmological(Interped): @property diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py index d594d2971f813d5a9d0570eed07da497557ded18..f564ead636214c5f861f73826d03e7c0136fbdbf 100644 --- a/test/gw_prior_test.py +++ b/test/gw_prior_test.py @@ -7,8 +7,14 @@ import pickle import numpy as np from astropy import cosmology +from scipy.stats import ks_2samp +import matplotlib.pyplot as plt +import pandas as pd import bilby +from bilby.core.prior import Uniform, Constraint +from bilby.gw.prior import BBHPriorDict +from bilby.gw import conversion class TestBBHPriorDict(unittest.TestCase): @@ -122,6 +128,67 @@ class TestBBHPriorDict(unittest.TestCase): self.assertEqual(priors, priors_loaded) +class TestPriorConversion(unittest.TestCase): + + def test_bilby_to_lalinference(self): + mass_1 = [1, 20] + mass_2 = [1, 20] + chirp_mass = [1, 5] + mass_ratio = [0.125, 1] + + bilby_prior = BBHPriorDict(dictionary=dict( + chirp_mass=Uniform(name='chirp_mass', minimum=chirp_mass[0], maximum=chirp_mass[1]), + mass_ratio=Uniform(name='mass_ratio', minimum=mass_ratio[0], maximum=mass_ratio[1]), + mass_2=Constraint(name='mass_2', minimum=mass_1[0], maximum=mass_1[1]), + mass_1=Constraint(name='mass_1', minimum=mass_2[0], maximum=mass_2[1]))) + + lalinf_prior = BBHPriorDict(dictionary=dict( + chirp_mass=Constraint(name='chirp_mass', minimum=chirp_mass[0], maximum=chirp_mass[1]), + mass_2=Uniform(name='mass_2', minimum=mass_1[0], maximum=mass_1[1]), + mass_1=Uniform(name='mass_1', minimum=mass_2[0], maximum=mass_2[1]))) + + nsamples = 5000 + bilby_samples = bilby_prior.sample(nsamples) + bilby_samples, _ = conversion.convert_to_lal_binary_black_hole_parameters( + bilby_samples) + + # Quicker way to generate LA prior samples (rather than specifying Constraint) + lalinf_samples = [] + while len(lalinf_samples) < nsamples: + s = lalinf_prior.sample() + if s["mass_1"] < s["mass_2"]: + s["mass_1"], s["mass_2"] = s["mass_2"], s["mass_1"] + if s["mass_2"] / s["mass_1"] > 0.125: + lalinf_samples.append(s) + lalinf_samples = pd.DataFrame(lalinf_samples) + lalinf_samples["mass_ratio"] = lalinf_samples["mass_2"] / lalinf_samples["mass_1"] + + # Construct fake result object + result = bilby.core.result.Result() + result.search_parameter_keys = ["mass_ratio", "chirp_mass"] + result.meta_data = dict() + result.priors = bilby_prior + result.posterior = pd.DataFrame(bilby_samples) + result_converted = bilby.gw.prior.convert_to_flat_in_component_mass_prior(result) + + if "plot" in sys.argv: + # Useful for debugging + plt.hist(bilby_samples["mass_ratio"], bins=50, density=True, alpha=0.5) + plt.hist(result_converted.posterior["mass_ratio"], bins=50, density=True, alpha=0.5) + plt.hist(lalinf_samples["mass_ratio"], bins=50, alpha=0.5, density=True) + plt.show() + + # Check that the non-reweighted posteriors fail a KS test + ks = ks_2samp(bilby_samples["mass_ratio"], lalinf_samples["mass_ratio"]) + print("Non-reweighted KS test = ", ks) + self.assertFalse(ks.pvalue > 0.05) + + # Check that the non-reweighted posteriors pass a KS test + ks = ks_2samp(result_converted.posterior["mass_ratio"], lalinf_samples["mass_ratio"]) + print("Reweighted KS test = ", ks) + self.assertTrue(ks.pvalue > 0.001) + + class TestPackagedPriors(unittest.TestCase): """ Test that the prepackaged priors load """