Commit 0cd77bb1 authored by Moritz Huebner's avatar Moritz Huebner

Moritz Huebner: Removed the property and created a function to initialise the result

parent bb91eb56
Pipeline #19341 passed with stages
in 7 minutes and 17 seconds
...@@ -59,18 +59,7 @@ class Sampler(object): ...@@ -59,18 +59,7 @@ class Sampler(object):
if os.path.isdir(outdir) is False: if os.path.isdir(outdir) is False:
os.makedirs(outdir) os.makedirs(outdir)
@property self.result = self.initialise_result()
def result(self):
self.__result = Result()
self.__result.search_parameter_keys = self.__search_parameter_keys
self.__result.fixed_parameter_keys = self.__fixed_parameter_keys
self.__result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
self.__result.label = self.label
self.__result.outdir = self.outdir
self.__result.priors = self.priors
self.__result.kwargs = self.kwargs
@property @property
def search_parameter_keys(self): def search_parameter_keys(self):
...@@ -117,7 +106,7 @@ class Sampler(object): ...@@ -117,7 +106,7 @@ class Sampler(object):
if user_input not in args: if user_input not in args:
logging.warning( logging.warning(
"Supplied argument '{}' not an argument of '{}', removing." "Supplied argument '{}' not an argument of '{}', removing."
.format(user_input, self.external_sampler_function)) .format(user_input, self.external_sampler_function))
bad_keys.append(user_input) bad_keys.append(user_input)
for key in bad_keys: for key in bad_keys:
self.kwargs.pop(key) self.kwargs.pop(key)
...@@ -140,6 +129,19 @@ class Sampler(object): ...@@ -140,6 +129,19 @@ class Sampler(object):
for key in self.__fixed_parameter_keys: for key in self.__fixed_parameter_keys:
logging.info(' {} = {}'.format(key, self.priors[key].peak)) logging.info(' {} = {}'.format(key, self.priors[key].peak))
def initialise_result(self):
result = Result()
result.search_parameter_keys = self.__search_parameter_keys
result.fixed_parameter_keys = self.__fixed_parameter_keys
result.parameter_labels = [
self.priors[k].latex_label for k in
self.__search_parameter_keys]
result.label = self.label
result.outdir = self.outdir
result.priors = self.priors
result.kwargs = self.kwargs
return result
def verify_parameters(self): def verify_parameters(self):
required_keys = self.priors required_keys = self.priors
unmatched_keys = [r for r in required_keys if r not in self.likelihood.parameters] unmatched_keys = [r for r in required_keys if r not in self.likelihood.parameters]
...@@ -153,7 +155,7 @@ class Sampler(object): ...@@ -153,7 +155,7 @@ class Sampler(object):
def log_prior(self, theta): def log_prior(self, theta):
return np.sum( return np.sum(
[np.log(self.priors[key].prob(t)) for key, t in [np.log(self.priors[key].prob(t)) for key, t in
zip(self.__search_parameter_keys, theta)]) zip(self.__search_parameter_keys, theta)])
def log_likelihood(self, theta): def log_likelihood(self, theta):
for i, k in enumerate(self.__search_parameter_keys): for i, k in enumerate(self.__search_parameter_keys):
...@@ -174,7 +176,7 @@ class Sampler(object): ...@@ -174,7 +176,7 @@ class Sampler(object):
""" """
draw = np.array([self.priors[key].sample() draw = np.array([self.priors[key].sample()
for key in self.__search_parameter_keys]) for key in self.__search_parameter_keys])
if np.isinf(self.log_likelihood(draw)): if np.isinf(self.log_likelihood(draw)):
logging.info('Prior draw {} has inf likelihood'.format(draw)) logging.info('Prior draw {} has inf likelihood'.format(draw))
if np.isinf(self.log_prior(draw)): if np.isinf(self.log_prior(draw)):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment