Skip to content
Snippets Groups Projects
Commit e4dd231c authored by Colm Talbot's avatar Colm Talbot
Browse files

TEST: add unit test of likelihood reweighting

parent 3362a0a4
No related branches found
No related tags found
No related merge requests found
......@@ -751,5 +751,71 @@ class TestPPPlots(unittest.TestCase):
)
class SimpleGaussianLikelihood(bilby.core.likelihood.Likelihood):
def __init__(self, mean=0, sigma=1):
"""
A very simple Gaussian likelihood for testing
"""
from scipy.stats import norm
super().__init__(parameters=dict())
self.mean = mean
self.sigma = sigma
self.dist = norm(loc=mean, scale=sigma)
def log_likelihood(self):
return self.dist.logpdf(self.parameters["mu"])
class TestReweight(unittest.TestCase):
def setUp(self):
self.priors = bilby.core.prior.PriorDict(dict(
mu=bilby.core.prior.TruncatedNormal(0, 1, minimum=-5, maximum=5),
))
self.result = bilby.core.result.Result(
search_parameter_keys=list(self.priors.keys()),
priors=self.priors,
posterior=pd.DataFrame(self.priors.sample(1000)),
log_evidence=-np.log(10),
)
def _run_reweighting(self, sigma):
likelihood_1 = SimpleGaussianLikelihood()
likelihood_2 = SimpleGaussianLikelihood(sigma=sigma)
original_ln_likelihoods = list()
for ii in range(len(self.result.posterior)):
likelihood_1.parameters = self.result.posterior.iloc[ii]
original_ln_likelihoods.append(likelihood_1.log_likelihood())
self.result.posterior["log_prior"] = self.priors.ln_prob(self.result.posterior)
self.result.posterior["log_likelihood"] = original_ln_likelihoods
self.original_ln_likelihoods = original_ln_likelihoods
return bilby.core.result.reweight(
self.result, likelihood_1, likelihood_2, verbose_output=True
)
def test_reweight_same_likelihood_weights_1(self):
"""
When the likelihoods are the same, the weights should be 1.
"""
_, weights, _, _, _, _ = self._run_reweighting(sigma=1)
self.assertLess(min(abs(weights - 1)), 1e-10)
def test_reweight_different_likelihood_weights_correct(self):
"""
Test the known case where the target likelihood is a Gaussian with
sigma=0.5. The weights can be calculated analytically and the evidence
should be close to the original evidence within statistical error.
"""
from scipy.stats import norm
new, weights, _, _, _, _ = self._run_reweighting(sigma=0.5)
expected_weights = (
norm(0, 0.5).pdf(self.result.posterior["mu"])
/ norm(0, 1).pdf(self.result.posterior["mu"])
)
self.assertLess(min(abs(weights - expected_weights)), 1e-10)
self.assertLess(abs(new.log_evidence - self.result.log_evidence), 0.05)
self.assertNotEqual(new.log_evidence, self.result.log_evidence)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment