Skip to content
Snippets Groups Projects
Commit 0cd77bb1 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Moritz Huebner: Removed the property and created a function to initialise the result

parent bb91eb56
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -59,18 +59,7 @@ class Sampler(object):
if os.path.isdir(outdir) is False:
os.makedirs(outdir)
@property
def result(self):
self.__result = Result()
self.__result.search_parameter_keys = self.__search_parameter_keys
self.__result.fixed_parameter_keys = self.__fixed_parameter_keys
self.__result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
self.__result.label = self.label
self.__result.outdir = self.outdir
self.__result.priors = self.priors
self.__result.kwargs = self.kwargs
self.result = self.initialise_result()
@property
def search_parameter_keys(self):
......@@ -117,7 +106,7 @@ class Sampler(object):
if user_input not in args:
logging.warning(
"Supplied argument '{}' not an argument of '{}', removing."
.format(user_input, self.external_sampler_function))
.format(user_input, self.external_sampler_function))
bad_keys.append(user_input)
for key in bad_keys:
self.kwargs.pop(key)
......@@ -140,6 +129,19 @@ class Sampler(object):
for key in self.__fixed_parameter_keys:
logging.info(' {} = {}'.format(key, self.priors[key].peak))
def initialise_result(self):
result = Result()
result.search_parameter_keys = self.__search_parameter_keys
result.fixed_parameter_keys = self.__fixed_parameter_keys
result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
result.label = self.label
result.outdir = self.outdir
result.priors = self.priors
result.kwargs = self.kwargs
return result
def verify_parameters(self):
required_keys = self.priors
unmatched_keys = [r for r in required_keys if r not in self.likelihood.parameters]
......@@ -153,7 +155,7 @@ class Sampler(object):
def log_prior(self, theta):
return np.sum(
[np.log(self.priors[key].prob(t)) for key, t in
zip(self.__search_parameter_keys, theta)])
zip(self.__search_parameter_keys, theta)])
def log_likelihood(self, theta):
for i, k in enumerate(self.__search_parameter_keys):
......@@ -174,7 +176,7 @@ class Sampler(object):
"""
draw = np.array([self.priors[key].sample()
for key in self.__search_parameter_keys])
for key in self.__search_parameter_keys])
if np.isinf(self.log_likelihood(draw)):
logging.info('Prior draw {} has inf likelihood'.format(draw))
if np.isinf(self.log_prior(draw)):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment