Skip to content
Snippets Groups Projects
Commit ad254f2e authored by moritz's avatar moritz
Browse files

Merge branch 'sampler_code_cleanup'

parents c3179287 593fc829
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -13,3 +13,4 @@ NRSur7dq2
nestle
deepdish
ptemcee
mock
from context import tupak
import unittest
from mock import Mock
import numpy as np
......@@ -153,5 +154,41 @@ class TestPriorClasses(unittest.TestCase):
self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
class TestFillPrior(unittest.TestCase):
def setUp(self):
self.likelihood = Mock()
self.likelihood.parameters = dict(a=0, b=0, c=0, d=0, asdf=0, ra=1)
self.priors = dict(a=1, b=1.1, c='string', d=tupak.prior.Uniform(0, 1))
self.priors = tupak.prior.fill_priors(self.priors, self.likelihood)
def tearDown(self):
del self.likelihood
del self.priors
def test_prior_instances_are_not_changed_by_parsing(self):
self.assertIsInstance(self.priors['d'], tupak.prior.Uniform)
def test_parsing_ints_to_delta_priors_class(self):
self.assertIsInstance(self.priors['a'], tupak.prior.DeltaFunction)
def test_parsing_ints_to_delta_priors_with_right_value(self):
self.assertEqual(self.priors['a'].peak, 1)
def test_parsing_floats_to_delta_priors_class(self):
self.assertIsInstance(self.priors['b'], tupak.prior.DeltaFunction)
def test_parsing_floats_to_delta_priors_with_right_value(self):
self.assertAlmostEqual(self.priors['b'].peak, 1.1, 1e-8)
def test_without_available_default_priors_no_prior_is_set(self):
with self.assertRaises(KeyError):
print(self.priors['asdf'])
def test_with_available_default_priors_a_default_prior_is_set(self):
self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform)
if __name__ == '__main__':
unittest.main()
......@@ -45,12 +45,12 @@ class Sampler(object):
self.outdir = outdir
self.use_ratio = use_ratio
self.external_sampler = external_sampler
self.external_sampler_function = None
self.__search_parameter_keys = []
self.__fixed_parameter_keys = []
self.initialise_parameters()
self.verify_parameters()
self.ndim = len(self.__search_parameter_keys)
self.kwargs = kwargs
self.result = result
......@@ -88,6 +88,17 @@ class Sampler(object):
def fixed_parameter_keys(self):
return self.__fixed_parameter_keys
@property
def ndim(self):
return len(self.__search_parameter_keys)
@property
def kwargs(self):
return self.__kwargs
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = kwargs
@property
def external_sampler(self):
......@@ -107,14 +118,6 @@ class Sampler(object):
raise TypeError('sampler must either be a string referring to built in sampler or a custom made class that '
'inherits from sampler')
@property
def kwargs(self):
return self.__kwargs
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = kwargs
def verify_kwargs_against_external_sampler_function(self):
args = inspect.getargspec(self.external_sampler_function).args
bad_keys = []
......@@ -453,7 +456,7 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
else:
result.log_bayes_factor = result.logz - result.noise_logz
result.injection_parameters = injection_parameters
result.fixed_parameter_keys = [key for key in priors if isinstance(key, prior.DeltaFunction)]
result.fixed_parameter_keys = sampler.fixed_parameter_keys
result.priors = priors
result.kwargs = sampler.kwargs
result.samples_to_data_frame()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment