diff --git a/requirements.txt b/requirements.txt index 1b7318e6cb811a64c33f6e8b1b55c53e36bfecc7..636cb022bbb446aa2d3eb86133f7b5fdc72d9fb9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ NRSur7dq2 nestle deepdish ptemcee +mock diff --git a/test/prior_tests.py b/test/prior_tests.py index 2575854cc00c19fc7aed126ff402e2ff1561bd83..6bc4853296cc826c978dec8a1607aa65eb8482eb 100644 --- a/test/prior_tests.py +++ b/test/prior_tests.py @@ -1,5 +1,6 @@ from context import tupak import unittest +from mock import Mock import numpy as np @@ -153,5 +154,41 @@ class TestPriorClasses(unittest.TestCase): self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3) +class TestFillPrior(unittest.TestCase): + + def setUp(self): + self.likelihood = Mock() + self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1) + self.priors = dict(a=1, b=1.1, c='string', d=tupak.prior.Uniform(0, 1)) + self.priors = tupak.prior.fill_priors(self.priors, self.likelihood) + + def tearDown(self): + del self.likelihood + del self.priors + + def test_prior_instances_are_not_changed_by_parsing(self): + self.assertIsInstance(self.priors['d'], tupak.prior.Uniform) + + def test_parsing_ints_to_delta_priors_class(self): + self.assertIsInstance(self.priors['a'], tupak.prior.DeltaFunction) + + def test_parsing_ints_to_delta_priors_with_right_value(self): + self.assertEqual(self.priors['a'].peak, 1) + + def test_parsing_floats_to_delta_priors_class(self): + self.assertIsInstance(self.priors['b'], tupak.prior.DeltaFunction) + + def test_parsing_floats_to_delta_priors_with_right_value(self): + self.assertAlmostEqual(self.priors['b'].peak, 1.1, 1e-8) + + def test_without_available_default_priors_no_prior_is_set(self): + with self.assertRaises(KeyError): + print(self.priors['asdf']) + + def test_with_available_default_priors_a_default_prior_is_set(self): + self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform) + + + if __name__ == '__main__': unittest.main() diff --git a/tupak/sampler.py b/tupak/sampler.py index 0f965bb4eee0d8214012bf02c39f206fd2817781..66559f2e237d3828d81842414b94ad271d9ffbf4 100644 --- a/tupak/sampler.py +++ b/tupak/sampler.py @@ -45,12 +45,12 @@ class Sampler(object): self.outdir = outdir self.use_ratio = use_ratio self.external_sampler = external_sampler + self.external_sampler_function = None self.__search_parameter_keys = [] self.__fixed_parameter_keys = [] self.initialise_parameters() self.verify_parameters() - self.ndim = len(self.__search_parameter_keys) self.kwargs = kwargs self.result = result @@ -88,6 +88,17 @@ class Sampler(object): def fixed_parameter_keys(self): return self.__fixed_parameter_keys + @property + def ndim(self): + return len(self.__search_parameter_keys) + + @property + def kwargs(self): + return self.__kwargs + + @kwargs.setter + def kwargs(self, kwargs): + self.__kwargs = kwargs @property def external_sampler(self): @@ -107,14 +118,6 @@ class Sampler(object): raise TypeError('sampler must either be a string referring to built in sampler or a custom made class that ' 'inherits from sampler') - @property - def kwargs(self): - return self.__kwargs - - @kwargs.setter - def kwargs(self, kwargs): - self.__kwargs = kwargs - def verify_kwargs_against_external_sampler_function(self): args = inspect.getargspec(self.external_sampler_function).args bad_keys = [] @@ -453,7 +456,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir', else: result.log_bayes_factor = result.logz - result.noise_logz result.injection_parameters = injection_parameters - result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)] + result.fixed_parameter_keys = sampler.fixed_parameter_keys result.priors = priors result.kwargs = sampler.kwargs result.samples_to_data_frame()