Skip to content
Snippets Groups Projects
Commit 5c9b5709 authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

Add cdf method to PriorDict classes

parent da290af1
No related branches found
No related tags found
1 merge request!943Add cdf method to PriorDict classes
...@@ -491,6 +491,21 @@ class PriorDict(dict): ...@@ -491,6 +491,21 @@ class PriorDict(dict):
constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio) constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
return constrained_ln_prob return constrained_ln_prob
def cdf(self, sample):
"""Evaluate the cumulative distribution function at the provided points
Parameters
----------
sample: dict, pandas.DataFrame
Dictionary of the samples of which to calculate the CDF
Returns
-------
dict, pandas.DataFrame: Dictionary containing the CDF values
"""
return sample.__class__({key: self[key].cdf(sample) for key, sample in sample.items()})
def rescale(self, keys, theta): def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior """Rescale samples from unit cube to prior
...@@ -654,9 +669,7 @@ class ConditionalPriorDict(PriorDict): ...@@ -654,9 +669,7 @@ class ConditionalPriorDict(PriorDict):
float: Joint probability of all individual sample probabilities float: Joint probability of all individual sample probabilities
""" """
self._check_resolved() self._prepare_evaluation(*zip(*sample.items()))
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample] res = [self[key].prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.product(res, **kwargs) return np.product(res, **kwargs)
...@@ -675,12 +688,15 @@ class ConditionalPriorDict(PriorDict): ...@@ -675,12 +688,15 @@ class ConditionalPriorDict(PriorDict):
float: Joint log probability of all the individual sample probabilities float: Joint log probability of all the individual sample probabilities
""" """
self._check_resolved() self._prepare_evaluation(*zip(*sample.items()))
for key, value in sample.items():
self[key].least_recently_sampled = value
res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample] res = [self[key].ln_prob(sample[key], **self.get_required_variables(key)) for key in sample]
return np.sum(res, axis=axis) return np.sum(res, axis=axis)
def cdf(self, sample):
self._prepare_evaluation(*zip(*sample.items()))
res = {key: self[key].cdf(sample[key], **self.get_required_variables(key)) for key in sample}
return sample.__class__(res)
def rescale(self, keys, theta): def rescale(self, keys, theta):
"""Rescale samples from unit cube to prior """Rescale samples from unit cube to prior
...@@ -695,12 +711,14 @@ class ConditionalPriorDict(PriorDict): ...@@ -695,12 +711,14 @@ class ConditionalPriorDict(PriorDict):
======= =======
list: List of floats containing the rescaled sample list: List of floats containing the rescaled sample
""" """
keys = list(keys)
theta = list(theta)
self._check_resolved() self._check_resolved()
self._update_rescale_keys(keys) self._update_rescale_keys(keys)
result = dict() result = dict()
for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes): for key, index in zip(self.sorted_keys_without_fixed_parameters, self._rescale_indexes):
required_variables = {k: result[k] for k in getattr(self[key], 'required_variables', [])} result[key] = self[key].rescale(theta[index], **self.get_required_variables(key))
result[key] = self[key].rescale(theta[index], **required_variables) self[key].least_recently_sampled = result[key]
return [result[key] for key in keys] return [result[key] for key in keys]
def _update_rescale_keys(self, keys): def _update_rescale_keys(self, keys):
...@@ -708,6 +726,11 @@ class ConditionalPriorDict(PriorDict): ...@@ -708,6 +726,11 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters] self._rescale_indexes = [keys.index(element) for element in self.sorted_keys_without_fixed_parameters]
self._least_recently_rescaled_keys = keys self._least_recently_rescaled_keys = keys
def _prepare_evaluation(self, keys, theta):
self._check_resolved()
for key, value in zip(keys, theta):
self[key].least_recently_sampled = value
def _check_resolved(self): def _check_resolved(self):
if not self._resolved: if not self._resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.") raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
......
import unittest import unittest
import mock import mock
import numpy as np
import bilby import bilby
...@@ -170,13 +171,13 @@ class TestConditionalPrior(unittest.TestCase): ...@@ -170,13 +171,13 @@ class TestConditionalPrior(unittest.TestCase):
class TestConditionalPriorDict(unittest.TestCase): class TestConditionalPriorDict(unittest.TestCase):
def setUp(self): def setUp(self):
def condition_func_1(reference_parameters, var_0): def condition_func_1(reference_parameters, var_0):
return reference_parameters return dict(minimum=reference_parameters["minimum"], maximum=var_0)
def condition_func_2(reference_parameters, var_0, var_1): def condition_func_2(reference_parameters, var_0, var_1):
return reference_parameters return dict(minimum=reference_parameters["minimum"], maximum=var_1)
def condition_func_3(reference_parameters, var_1, var_2): def condition_func_3(reference_parameters, var_1, var_2):
return reference_parameters return dict(minimum=reference_parameters["minimum"], maximum=var_2)
self.minimum = 0 self.minimum = 0
self.maximum = 1 self.maximum = 1
...@@ -203,7 +204,8 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -203,7 +204,8 @@ class TestConditionalPriorDict(unittest.TestCase):
self.conditional_priors_manually_set_items = ( self.conditional_priors_manually_set_items = (
bilby.core.prior.ConditionalPriorDict() bilby.core.prior.ConditionalPriorDict()
) )
self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4) self.test_sample = dict(var_0=0.7, var_1=0.6, var_2=0.5, var_3=0.4)
self.test_value = 1 / np.prod([self.test_sample[f"var_{ii}"] for ii in range(3)])
for key, value in dict( for key, value in dict(
var_0=self.prior_0, var_0=self.prior_0,
var_1=self.prior_1, var_1=self.prior_1,
...@@ -254,7 +256,7 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -254,7 +256,7 @@ class TestConditionalPriorDict(unittest.TestCase):
) )
def test_prob(self): def test_prob(self):
self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample)) self.assertEqual(self.test_value, self.conditional_priors.prob(sample=self.test_sample))
def test_prob_illegal_conditions(self): def test_prob_illegal_conditions(self):
del self.conditional_priors["var_0"] del self.conditional_priors["var_0"]
...@@ -262,7 +264,7 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -262,7 +264,7 @@ class TestConditionalPriorDict(unittest.TestCase):
self.conditional_priors.prob(sample=self.test_sample) self.conditional_priors.prob(sample=self.test_sample)
def test_ln_prob(self): def test_ln_prob(self):
self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample)) self.assertEqual(np.log(self.test_value), self.conditional_priors.ln_prob(sample=self.test_sample))
def test_ln_prob_illegal_conditions(self): def test_ln_prob_illegal_conditions(self):
del self.conditional_priors["var_0"] del self.conditional_priors["var_0"]
...@@ -273,7 +275,7 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -273,7 +275,7 @@ class TestConditionalPriorDict(unittest.TestCase):
with mock.patch("numpy.random.uniform") as m: with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5 m.return_value = 0.5
self.assertDictEqual( self.assertDictEqual(
dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5), dict(var_0=0.5, var_1=0.5 ** 2, var_2=0.5 ** 3, var_3=0.5 ** 4),
self.conditional_priors.sample_subset( self.conditional_priors.sample_subset(
keys=["var_0", "var_1", "var_2", "var_3"] keys=["var_0", "var_1", "var_2", "var_3"]
), ),
...@@ -301,31 +303,6 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -301,31 +303,6 @@ class TestConditionalPriorDict(unittest.TestCase):
print(priors.sample(2)) print(priors.sample(2))
def test_rescale(self): def test_rescale(self):
def condition_func_1_rescale(reference_parameters, var_0):
if var_0 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_2_rescale(reference_parameters, var_0, var_1):
if var_0 == 0.5 and var_1 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
def condition_func_3_rescale(reference_parameters, var_1, var_2):
if var_1 == 0.5 and var_2 == 0.5:
return dict(minimum=reference_parameters["minimum"], maximum=1)
return reference_parameters
self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
self.prior_1 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_1_rescale, minimum=self.minimum, maximum=2
)
self.prior_2 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_2_rescale, minimum=self.minimum, maximum=2
)
self.prior_3 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_3_rescale, minimum=self.minimum, maximum=2
)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict( self.conditional_priors = bilby.core.prior.ConditionalPriorDict(
dict( dict(
var_3=self.prior_3, var_3=self.prior_3,
...@@ -334,11 +311,28 @@ class TestConditionalPriorDict(unittest.TestCase): ...@@ -334,11 +311,28 @@ class TestConditionalPriorDict(unittest.TestCase):
var_1=self.prior_1, var_1=self.prior_1,
) )
) )
ref_variables = [0.5, 0.5, 0.5, 0.5] ref_variables = self.test_sample.values()
res = self.conditional_priors.rescale( res = self.conditional_priors.rescale(
keys=list(self.test_sample.keys()), theta=ref_variables keys=self.test_sample.keys(), theta=ref_variables
)
expected = [self.test_sample["var_0"]]
for ii in range(1, 4):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res)
def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Note that the format of inputs/outputs is different between the two methods.
"""
sample = self.conditional_priors.sample()
self.assertEqual(
self.conditional_priors.rescale(
sample.keys(),
self.conditional_priors.cdf(sample=sample).values()
), list(sample.values())
) )
self.assertListEqual(ref_variables, res)
def test_rescale_illegal_conditions(self): def test_rescale_illegal_conditions(self):
del self.conditional_priors["var_0"] del self.conditional_priors["var_0"]
......
...@@ -298,6 +298,20 @@ class TestPriorDict(unittest.TestCase): ...@@ -298,6 +298,20 @@ class TestPriorDict(unittest.TestCase):
), ),
) )
def test_cdf(self):
"""
Test that the CDF method is the inverse of the rescale method.
Note that the format of inputs/outputs is different between the two methods.
"""
sample = self.prior_set_from_dict.sample()
self.assertEqual(
self.prior_set_from_dict.rescale(
sample.keys(),
self.prior_set_from_dict.cdf(sample=sample).values()
), list(sample.values())
)
def test_redundancy(self): def test_redundancy(self):
for key in self.prior_set_from_dict.keys(): for key in self.prior_set_from_dict.keys():
self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key)) self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment