Commit 3183d305 authored by Colm Talbot's avatar Colm Talbot

update prior tests to include uniform in log

parent 04902daf
Pipeline #19572 passed with stages
in 6 minutes and 3 seconds
......@@ -92,8 +92,9 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.DeltaFunction(name='test', peak=1),
tupak.prior.Gaussian(name='test', mu=0, sigma=1),
tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2),
tupak.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2),
tupak.prior.Uniform(name='test', minimum=0, maximum=1),
tupak.prior.LogUniform(name='test', minimum=5e0, maximum=1e2),
tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
tupak.prior.Sine(name='test'),
tupak.prior.Cosine(name='test'),
......@@ -102,26 +103,42 @@ class TestPriorClasses(unittest.TestCase):
tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1)
]
def test_rescaling(self):
def test_minimum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
"""Test the the rescaling works as expected."""
minimum_sample = prior.rescale(0)
self.assertAlmostEqual(minimum_sample, prior.minimum)
def test_maximum_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
maximum_sample = prior.rescale(1)
self.assertAlmostEqual(maximum_sample, prior.maximum)
def test_many_sample_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_out_of_bounds_rescaling(self):
"""Test the the rescaling works as expected."""
for prior in self.priors:
self.assertRaises(ValueError, lambda: prior.rescale(-1))
def test_sampling(self):
def test_sampling_single(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
single_sample = prior.sample()
self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum))
def test_sampling_many(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
many_samples = prior.sample(1000)
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_prob(self):
def test_probability_above_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
......@@ -130,15 +147,36 @@ class TestPriorClasses(unittest.TestCase):
if prior.maximum != np.inf:
outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_below_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum != -np.inf:
outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
def test_probability_in_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.minimum == -np.inf:
prior.minimum = -1e5
if prior.maximum == np.inf:
prior.maximum = 1e5
domain = np.linspace(prior.minimum, prior.maximum, 1000)
self.assertTrue(all(prior.prob(domain) >= 0))
def test_probability_surrounding_domain(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000)
prior.prob(surround_domain)
......@@ -151,6 +189,7 @@ class TestPriorClasses(unittest.TestCase):
domain = np.linspace(-1e2, 1e2, 1000)
else:
domain = np.linspace(prior.minimum, prior.maximum, 1000)
print(prior.minimum, prior.maximum)
self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
......@@ -189,6 +228,5 @@ class TestFillPrior(unittest.TestCase):
self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment