Skip to content
Snippets Groups Projects
Commit f67b4961 authored by Moritz's avatar Moritz
Browse files

Some modifications to ConditionalPriorDict after testing

parent 4adcd871
No related branches found
No related tags found
1 merge request!332Resolve "Introduce conditional prior sets"
Pipeline #84110 failed
......@@ -501,10 +501,10 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = []
self._least_recently_rescaled_keys = []
super(ConditionalPriorDict, self).__init__(dictionary=dictionary, filename=filename)
self.resolved = False
self._resolved = False
self._resolve_conditions()
def _resolve_conditions(self, disable_log=False):
def _resolve_conditions(self):
""" Resolves how variables depend on each other and automatically sorts them into the right order """
conditioned_keys_unsorted = [key for key in self.keys() if hasattr(self[key], 'condition_func')]
self._unconditional_keys = [key for key in self.keys() if not hasattr(self[key], 'condition_func')]
......@@ -519,12 +519,10 @@ class ConditionalPriorDict(PriorDict):
self._sorted_keys = self._unconditional_keys.copy()
self._sorted_keys.extend(self.conditional_keys)
self.resolved = True
self._resolved = True
if len(conditioned_keys_unsorted) != 0:
self.resolved = False
if not disable_log:
logger.warning('This set contains unresolvable conditions')
self._resolved = False
def _check_conditions_resolved(self, key, sampled_keys):
""" Checks if all required variables have already been sampled so we can sample this key """
......@@ -537,7 +535,7 @@ class ConditionalPriorDict(PriorDict):
def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
subset_dict = ConditionalPriorDict({key: self[key] for key in keys})
if not subset_dict.resolved:
if not subset_dict._resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
res = dict()
for key in subset_dict.sorted_keys:
......@@ -634,7 +632,7 @@ class ConditionalPriorDict(PriorDict):
self._rescale_indexes = np.append(unconditional_idxs, conditional_idxs)
def _check_resolved(self):
if not self.resolved:
if not self._resolved:
raise IllegalConditionsException("The current set of priors contains unresolveable conditions.")
@property
......@@ -651,7 +649,11 @@ class ConditionalPriorDict(PriorDict):
def __setitem__(self, key, value):
super(ConditionalPriorDict, self).__setitem__(key, value)
self._resolve_conditions(disable_log=True)
self._resolve_conditions()
def __delitem__(self, key):
super(ConditionalPriorDict, self).__delitem__(key)
self._resolve_conditions()
def create_default_prior(name, default_priors_file=None):
......
......@@ -294,6 +294,11 @@ class TestPriorClasses(unittest.TestCase):
outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_least_recently_sampled(self):
for prior in self.priors:
lrs = prior.sample()
self.assertEqual(lrs, prior.least_recently_sampled)
def test_prob_and_ln_prob(self):
for prior in self.priors:
sample = prior.sample()
......@@ -852,7 +857,7 @@ class TestConditionalPrior(unittest.TestCase):
def setUp(self):
self.condition_func_call_counter = 0
def condition_func(reference_parameters, test_parameter_1, test_parameter_2):
def condition_func(reference_parameters, test_variable_1, test_variable_2):
self.condition_func_call_counter += 1
return {key: value + 1 for key, value in reference_parameters.items()}
self.condition_func = condition_func
......@@ -860,14 +865,17 @@ class TestConditionalPrior(unittest.TestCase):
self.maximum = 5
self.test_parameter_1 = 0
self.test_parameter_2 = 1
self.prior = bilby.core.prior.ConditionalPrior(condition_func=condition_func,
minimum=self.minimum,
maximum=self.maximum)
self.prior = bilby.core.prior.ConditionalBasePrior(condition_func=condition_func,
minimum=self.minimum,
maximum=self.maximum)
def tearDown(self):
del self.condition_func
del self.condition_func_call_counter
del self.minimum
del self.maximum
del self.test_parameter_1
del self.test_parameter_2
del self.prior
def test_reference_params(self):
......@@ -876,6 +884,11 @@ class TestConditionalPrior(unittest.TestCase):
def test_required_variables(self):
self.assertListEqual(['test_parameter_1', 'test_parameter_2'], sorted(self.prior.required_variables))
def test_required_variables_no_condition_func(self):
self.prior = bilby.core.prior.ConditionalBasePrior(minimum=self.minimum,
maximum=self.maximum)
self.assertListEqual([], self.prior.required_variables)
def test_get_instantiation_dict(self):
expected = dict(minimum=0, maximum=5, name=None, latex_label=None, unit=None,
boundary=None, condition_func=self.condition_func)
......@@ -937,6 +950,133 @@ class TestConditionalPrior(unittest.TestCase):
self.assertEqual(self.prior.reference_params['minimum'], self.prior.minimum)
self.assertEqual(self.prior.reference_params['maximum'], self.prior.maximum)
def test_cond_prior_instantiation_no_boundary_prior(self):
prior = bilby.core.prior.ConditionalFermiDirac(sigma=1)
self.assertIsNone(prior.boundary)
class TestConditionalPriorDict(unittest.TestCase):
def setUp(self):
def condition_func_1(reference_parameters, var_0):
return reference_parameters
def condition_func_2(reference_parameters, var_0, var_1):
return reference_parameters
def condition_func_3(reference_parameters, var_1, var_2):
return reference_parameters
self.minimum = 0
self.maximum = 1
self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=self.maximum)
self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1,
minimum=self.minimum, maximum=self.maximum)
self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2,
minimum=self.minimum, maximum=self.maximum)
self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3,
minimum=self.minimum, maximum=self.maximum)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
var_0=self.prior_0, var_1=self.prior_1))
self.conditional_priors_manually_set_items = bilby.core.prior.ConditionalPriorDict()
self.test_sample = dict(var_0=0.3, var_1=0.4, var_2=0.5, var_3=0.4)
for key, value in dict(var_0=self.prior_0, var_1=self.prior_1, var_2=self.prior_2, var_3=self.prior_3).items():
self.conditional_priors_manually_set_items[key] = value
def tearDown(self):
del self.minimum
del self.maximum
del self.prior_0
del self.prior_1
del self.prior_2
del self.prior_3
del self.conditional_priors
del self.conditional_priors_manually_set_items
del self.test_sample
def test_conditions_resolved_upon_instantiation(self):
self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'], self.conditional_priors.sorted_keys)
def test_conditions_resolved_setting_items(self):
self.assertListEqual(['var_0', 'var_1', 'var_2', 'var_3'],
self.conditional_priors_manually_set_items.sorted_keys)
def test_unconditional_keys_upon_instantiation(self):
self.assertListEqual(['var_0'], self.conditional_priors.unconditional_keys)
def test_unconditional_keys_setting_items(self):
self.assertListEqual(['var_0'], self.conditional_priors_manually_set_items.unconditional_keys)
def test_conditional_keys_upon_instantiation(self):
self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors.conditional_keys)
def test_conditional_keys_setting_items(self):
self.assertListEqual(['var_1', 'var_2', 'var_3'], self.conditional_priors_manually_set_items.conditional_keys)
def test_prob(self):
self.assertEqual(1, self.conditional_priors.prob(sample=self.test_sample))
def test_prob_illegal_conditions(self):
del self.conditional_priors['var_0']
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.prob(sample=self.test_sample)
def test_ln_prob(self):
self.assertEqual(0, self.conditional_priors.ln_prob(sample=self.test_sample))
def test_ln_prob_illegal_conditions(self):
del self.conditional_priors['var_0']
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.ln_prob(sample=self.test_sample)
def test_sample_subset_all_keys(self):
with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5
self.assertDictEqual(dict(var_0=0.5, var_1=0.5, var_2=0.5, var_3=0.5),
self.conditional_priors.sample_subset(keys=['var_0', 'var_1', 'var_2', 'var_3']))
def test_sample_illegal_subset(self):
with mock.patch("numpy.random.uniform") as m:
m.return_value = 0.5
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.sample_subset(keys=['var_1'])
def test_rescale(self):
def condition_func_1_rescale(reference_parameters, var_0):
if var_0 == 0.5:
return dict(minimum=reference_parameters['minimum'], maximum=1)
return reference_parameters
def condition_func_2_rescale(reference_parameters, var_0, var_1):
if var_0 == 0.5 and var_1 == 0.5:
return dict(minimum=reference_parameters['minimum'], maximum=1)
return reference_parameters
def condition_func_3_rescale(reference_parameters, var_1, var_2):
if var_1 == 0.5 and var_2 == 0.5:
return dict(minimum=reference_parameters['minimum'], maximum=1)
return reference_parameters
self.prior_0 = bilby.core.prior.Uniform(minimum=self.minimum, maximum=1)
self.prior_1 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_1_rescale,
minimum=self.minimum, maximum=2)
self.prior_2 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_2_rescale,
minimum=self.minimum, maximum=2)
self.prior_3 = bilby.core.prior.ConditionalUniform(condition_func=condition_func_3_rescale,
minimum=self.minimum, maximum=2)
self.conditional_priors = bilby.core.prior.ConditionalPriorDict(dict(var_3=self.prior_3, var_2=self.prior_2,
var_0=self.prior_0, var_1=self.prior_1))
ref_variables = [0.5, 0.5, 0.5, 0.5]
res = self.conditional_priors.rescale(keys=list(self.test_sample.keys()),
theta=ref_variables)
self.assertListEqual(ref_variables, res)
def test_rescale_illegal_conditions(self):
del self.conditional_priors['var_0']
with self.assertRaises(bilby.core.prior.IllegalConditionsException):
self.conditional_priors.rescale(keys=list(self.test_sample.keys()), theta=list(self.test_sample.values()))
class TestJsonIO(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment