Skip to content
Snippets Groups Projects
Commit 3183d305 authored by Colm Talbot's avatar Colm Talbot
Browse files

update prior tests to include uniform in log

parent 04902daf
No related branches found
No related tags found
1 merge request!42Add uniform in log prior
Pipeline #
......@@ -92,8 +92,9 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.DeltaFunction(name='test', peak=1),
tupak.prior.Gaussian(name='test', mu=0, sigma=1),
tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2),
tupak.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2),
tupak.prior.Uniform(name='test', minimum=0, maximum=1),
tupak.prior.LogUniform(name='test', minimum=5e0, maximum=1e2),
tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
tupak.prior.Sine(name='test'),
tupak.prior.Cosine(name='test'),
......@@ -102,26 +103,42 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1)
]
def test_rescaling(self):
def test_minimum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
"""Test the the rescaling works as expected."""
minimum_sample = prior.rescale(0)
self.assertAlmostEqual(minimum_sample, prior.minimum)
def test_maximum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
maximum_sample = prior.rescale(1)
self.assertAlmostEqual(maximum_sample, prior.maximum)
def test_many_sample_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_out_of_bounds_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
self.assertRaises(ValueError, lambda: prior.rescale(-1))
def test_sampling(self):
def test_sampling_single(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
single_sample = prior.sample()
self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum))
def test_sampling_many(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
many_samples = prior.sample(1000)
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_prob(self):
def test_probability_above_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
......@@ -130,15 +147,36 @@ class TestPriorClasses(unittest.TestCase):
if prior.maximum != np.inf:
outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_below_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum != -np.inf:
outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_in_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum == -np.inf:
prior.minimum = -1e5
if prior.maximum == np.inf:
prior.maximum = 1e5
domain = np.linspace(prior.minimum, prior.maximum, 1000)
self.assertTrue(all(prior.prob(domain) >= 0))
def test_probability_surrounding_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000)
prior.prob(surround_domain)
......@@ -151,6 +189,7 @@ class TestPriorClasses(unittest.TestCase):
domain = np.linspace(-1e2, 1e2, 1000)
else:
domain = np.linspace(prior.minimum, prior.maximum, 1000)
print(prior.minimum, prior.maximum)
self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
......@@ -189,6 +228,5 @@ class TestFillPrior(unittest.TestCase):
self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment