Skip to content
Snippets Groups Projects
Commit 869b7f7e authored by Colm Talbot's avatar Colm Talbot
Browse files

change logic on testing recaling validity

parent cc3a6f39
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -34,8 +34,8 @@ class Prior(object):
@staticmethod
def test_valid_for_rescaling(val):
"""Test if 0 < val < 1"""
if (val < 0) or (val > 1):
raise ValueError("Number to be rescaled should be in [0, 1]")
allowed = (val > 0) and (val < 1)
return allowed
def __repr__(self):
prior_name = self.__class__.__name__
......@@ -97,7 +97,8 @@ class Uniform(Prior):
self.support = maximum - minimum
def rescale(self, val):
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return self.minimum + val * self.support
def prob(self, val):
......@@ -117,7 +118,8 @@ class DeltaFunction(Prior):
def rescale(self, val):
"""Rescale everything to the peak with the correct shape."""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return self.peak * val ** 0
def prob(self, val):
......@@ -144,7 +146,8 @@ class PowerLaw(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
if self.alpha == -1:
return self.minimum * np.exp(val * np.log(self.maximum / self.minimum))
else:
......@@ -174,7 +177,8 @@ class Cosine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return np.arcsin(-1 + val * 2)
@staticmethod
......@@ -197,7 +201,8 @@ class Sine(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return np.arccos(-1 + val * 2)
@staticmethod
......@@ -224,7 +229,8 @@ class Gaussian(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return self.mu + erfinv(2 * val - 1) * 2**0.5 * self.sigma
def prob(self, val):
......@@ -256,7 +262,8 @@ class TruncatedGaussian(Prior):
This maps to the inverse CDF. This has been analytically solved for this case.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return erfinv(2 * val * self.normalisation + erf(
(self.minimum - self.mu) / 2 ** 0.5 / self.sigma)) * 2 ** 0.5 * self.sigma + self.mu
......@@ -305,7 +312,8 @@ class Interped(Prior):
This maps to the inverse CDF. This is done using interpolation.
"""
Prior.test_valid_for_rescaling(val)
if not all(Prior.test_valid_for_rescaling(val)):
logging.warning('Values to be rescaled must lie in [0, 1].')
return self.inverse_cumulative_distribution(val)
def __repr__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment