Skip to content
Snippets Groups Projects
Commit a4e681cc authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Moritz Huebner: Introduced a __repr__format_helper method to help automating...

Moritz Huebner: Introduced a __repr__format_helper method to help automating __repr__ formatting for all subclasses.
parent 8aa8546f
No related branches found
No related tags found
No related merge requests found
......@@ -71,9 +71,29 @@ class Prior(object):
def __repr__(self):
prior_name = self.__class__.__name__
prior_args = ', '.join(
['{}={}'.format(k, v) for k, v in self.__dict__.items()])
return "{}({})".format(prior_name, prior_args)
keys = ['name', '_Prior__latex_label', '_Interped__minimum', '_Interped__maximum']
names = ['name', 'latex_label', 'minimum', 'maximum']
args = self.__repr__format_helper(keys, names)
return "{}({})".format(prior_name, args)
def __repr__format_helper(self, keys, names):
string_keys = []
string_names = []
non_string_keys = []
non_string_names = []
for key, name in zip(keys, names):
if isinstance(self.__dict__[key], str): #TODO: check compatibility with Python 2
string_keys.append(key)
string_names.append(name)
else:
non_string_keys.append(key)
non_string_names.append(name)
args = ', '.join(['{}={}'.format(name, '\"' + self.__dict__[key] + '\"')
for key, name in zip(string_keys, string_names)])
args = args + ', ' + ', '.join(['{}={}'.format(name, self.__dict__[key])
for key, name in zip(non_string_keys, non_string_names)])
return args
@property
def is_fixed(self):
......@@ -175,8 +195,8 @@ class PowerLaw(Prior):
def lnprob(self, val):
in_prior = (val >= self.minimum) & (val <= self.maximum)
normalising = (1+self.alpha)/(self.maximum ** (1 + self.alpha)
- self.minimum ** (1 + self.alpha))
normalising = (1 + self.alpha) / (self.maximum ** (1 + self.alpha)
- self.minimum ** (1 + self.alpha))
return self.alpha * np.log(val) * np.log(normalising) * in_prior
......@@ -194,7 +214,7 @@ class LogUniform(PowerLaw):
def __init__(self, minimum, maximum, name=None, latex_label=None):
Prior.__init__(self, name, latex_label, minimum, maximum)
self.alpha = -1
if self.minimum<=0:
if self.minimum <= 0:
logging.warning('You specified a uniform-in-log prior with minimum={}'.format(self.minimum))
......@@ -254,14 +274,14 @@ class Gaussian(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
return self.mu + erfinv(2 * val - 1) * 2**0.5 * self.sigma
return self.mu + erfinv(2 * val - 1) * 2 ** 0.5 * self.sigma
def prob(self, val):
"""Return the prior probability of val"""
return np.exp(-(self.mu - val)**2 / (2 * self.sigma**2)) / (2 * np.pi)**0.5 / self.sigma
return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (2 * np.pi) ** 0.5 / self.sigma
def lnprob(self, val):
return -0.5*((self.mu - val)**2 / self.sigma**2 + np.log(2 * np.pi * self.sigma**2))
return -0.5 * ((self.mu - val) ** 2 / self.sigma ** 2 + np.log(2 * np.pi * self.sigma ** 2))
class TruncatedGaussian(Prior):
......@@ -296,7 +316,7 @@ class TruncatedGaussian(Prior):
"""Return the prior probability of val"""
in_prior = (val >= self.minimum) & (val <= self.maximum)
return np.exp(-(self.mu - val) ** 2 / (2 * self.sigma ** 2)) / (
2 * np.pi) ** 0.5 / self.sigma / self.normalisation * in_prior
2 * np.pi) ** 0.5 / self.sigma / self.normalisation * in_prior
class Interped(Prior):
......@@ -328,11 +348,9 @@ class Interped(Prior):
return rescaled
def __repr__(self):
prior_name = self.__class__.__name__
prior_args = ', '.join(
['{}={}'.format(name, self.__dict__[key]) for key, name in zip(['xx', 'yy', 'name', '_Prior__latex_label'],
['xx', 'yy', 'name', 'latex_label'])])
return "{}({})".format(prior_name, prior_args)
super_repr = Prior.__repr__(self).rstrip(')').__add__(',')
args = ', '.join(['{}={}'.format(name, self.__dict__[key]) for key, name in zip(['xx', 'yy'], ['xx', 'yy'])])
return super_repr + args + ")"
@property
def minimum(self):
......@@ -497,7 +515,7 @@ def fill_priors(prior, likelihood):
logging.warning(
"Parameter {} has no default prior and is set to {}, this will"
" not be sampled and may cause an error."
.format(missing_key, set_val))
.format(missing_key, set_val))
else:
if not test_redundancy(missing_key, prior):
prior[missing_key] = default_prior
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment