Skip to content
Snippets Groups Projects
Commit 95f2a457 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'prior-dict-cdf' into 'master'

Add cdf method to PriorDict classes

See merge request lscsoft/bilby!943
parents c7417e24 5c9b5709
No related branches found
No related tags found
1 merge request!943Add cdf method to PriorDict classes
Pipeline #249533 passed with warnings
......@@ -518,6 +518,21 @@ class PriorDict(dict):
constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
return constrained_ln_prob
def cdf(self, sample):
"""Evaluate the cumulative distribution function at the provided points
Parameters
----------
sample: dict, pandas.DataFrame
Dictionary of the samples of which to calculate the CDF
Returns
-------
dict, pandas.DataFrame: Dictionary containing the CDF values
"""
return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()})
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......@@ -681,9 +696,7 @@ class ConditionalPriorDict(PriorDict):
float: Joint probability of all individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
self._prepare_evaluation(*zip(*sample.items()))
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
prob = np.product(res, **kwargs)
return self.check_prob(sample, prob)
......@@ -703,13 +716,16 @@ class ConditionalPriorDict(PriorDict):
float: Joint log probability of all the individual sample probabilities
"""
self._check_resolved()
for key, value in sample.items():
self[key].least_recently_sampled = value
self._prepare_evaluation(*zip(*sample.items()))
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
ln_prob = np.sum(res, axis=axis)
return self.check_ln_prob(sample, ln_prob)
def cdf(self, sample):
self._prepare_evaluation(*zip(*sample.items()))
res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample}
return sample.__class__(res)
def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior
......@@ -724,12 +740,14 @@ class ConditionalPriorDict(PriorDict):
=======
list: List of floats containing the rescaled sample
"""
keys = list(keys)
theta = list(theta)
self._check_resolved()
self._update_rescale_keys(keys)
result = dict()
for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])}
result[key] = self[key].rescale(theta[index], **required_variables)
result[key] = self[key].rescale(theta[index], **self.get_required_variables(key))
self[key].least_recently_sampled = result[key]
return [result[key] for key in keys]
def _update_rescale_keys(self, keys):
......@@ -737,6 +755,11 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
self._least_recently_rescaled_keys = keys
def _prepare_evaluation(self, keys, theta):
self._check_resolved()
for key, value in zip(keys, theta):
self[key].least_recently_sampled = value
def _check_resolved(self):
if not self._resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
......
import unittest
import mock
import numpy as np
import bilby
......@@ -170,13 +171,13 @@ class TestConditionalPrior(unittest.TestCase):
class TestConditionalPriorDict(unittest.TestCase):
def setUp(self):
def condition_func_1(reference_parameters, var_0):
return reference_parameters
return dict(minimum=reference_parameters["minimum"], maximum=var_0)
def condition_func_2(reference_parameters, var_0, var_1):
return reference_parameters
return dict(minimum=reference_parameters["minimum"], maximum=var_1)
def condition_func_3(reference_parameters, var_1, var_2):
return reference_parameters
return dict(minimum=reference_parameters["minimum"], maximum=var_2)
self.minimum = 0
self.maximum = 1
......@@ -203,7 +204,8 @@ class TestConditionalPriorDict(unittest.TestCase):
self.conditional_priors_manually_set_items = (
bilby.core.prior.ConditionalPriorDict()
)
self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4)
self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)])
for key, value in dict(
var_0=self.prior_0,
var_1=self.prior_1,
......@@ -254,7 +256,7 @@ class TestConditionalPriorDict(unittest.TestCase):
)
def test_prob(self):
self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample))
def test_prob_illegal_conditions(self):
del self.conditional_priors["var_0"]
......@@ -262,7 +264,7 @@ class TestConditionalPriorDict(unittest.TestCase):
self.conditional_priors.prob(sample=self.test_sample)
def test_ln_prob(self):
self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample))
def test_ln_prob_illegal_conditions(self):
del self.conditional_priors["var_0"]
......@@ -273,7 +275,7 @@ class TestConditionalPriorDict(unittest.TestCase):
with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5
self.assertDictEqual(
dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
dict(var_0=0.5, var_1=0.5 ** 2, var_2=0.5 ** 3, var_3=0.5 ** 4),
self.conditional_priors.sample_subset(
keys=["var_0", "var_1", "var_2", "var_3"]
),
......@@ -301,31 +303,6 @@ class TestConditionalPriorDict(unittest.TestCase):
print(priors.sample(2))
def test_rescale(self):
def condition_func_1_rescale(reference_parameters, var_0):
if var_0 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_2_rescale(reference_parameters, var_0, var_1):
if var_0 == 0.5 and var_1 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_3_rescale(reference_parameters, var_1, var_2):
if var_1 == 0.5 and var_2 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
self.prior_1 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_1_rescale, minimum=self.minimum, maximum=2
)
self.prior_2 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_2_rescale, minimum=self.minimum, maximum=2
)
self.prior_3 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_3_rescale, minimum=self.minimum, maximum=2
)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict(
dict(
var_3=self.prior_3,
......@@ -334,11 +311,28 @@ class TestConditionalPriorDict(unittest.TestCase):
var_1=self.prior_1,
)
)
ref_variables = [0.5, 0.5, 0.5, 0.5]
ref_variables = self.test_sample.values()
res = self.conditional_priors.rescale(
keys=list(self.test_sample.keys()), theta=ref_variables
keys=self.test_sample.keys(), theta=ref_variables
)
expected = [self.test_sample["var_0"]]
for ii in range(1, 4):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res)
def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Note that the format of inputs/outputs is different between the two methods.
"""
sample = self.conditional_priors.sample()
self.assertEqual(
self.conditional_priors.rescale(
sample.keys(),
self.conditional_priors.cdf(sample=sample).values()
), list(sample.values())
)
self.assertListEqual(ref_variables, res)
def test_rescale_illegal_conditions(self):
del self.conditional_priors["var_0"]
......
......@@ -298,6 +298,20 @@ class TestPriorDict(unittest.TestCase):
),
)
def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Note that the format of inputs/outputs is different between the two methods.
"""
sample = self.prior_set_from_dict.sample()
self.assertEqual(
self.prior_set_from_dict.rescale(
sample.keys(),
self.prior_set_from_dict.cdf(sample=sample).values()
), list(sample.values())
)
def test_redundancy(self):
for key in self.prior_set_from_dict.keys():
self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment