Skip to content
Snippets Groups Projects
Commit 664eeef0 authored by Colm Talbot's avatar Colm Talbot
Browse files

add testing of specific prior distributions

parent 72f7da3d
No related branches found
No related tags found
1 merge request!35Prior tests
......@@ -111,5 +111,64 @@ class TestFixMethod(unittest.TestCase):
self.assertRaises(ValueError, tupak.prior.fix, self.prior, np.nan)
class TestPriorClasses(unittest.TestCase):
def setUp(self):
self.priors = [
tupak.prior.DeltaFunction(name='test', peak=1),
tupak.prior.Gaussian(name='test', mu=0, sigma=1),
tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2),
tupak.prior.Uniform(name='test', minimum=0, maximum=1),
tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
tupak.prior.Sine(name='test'),
tupak.prior.Cosine(name='test'),
tupak.prior.Interped(name='test', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000)**4,
minimum=3, maximum=5),
tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1)
]
def test_rescaling(self):
for prior in self.priors:
"""Test the the rescaling works as expected."""
minimum_sample = prior.rescale(0)
self.assertAlmostEqual(minimum_sample, prior.minimum)
maximum_sample = prior.rescale(1)
self.assertAlmostEqual(maximum_sample, prior.maximum)
many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
self.assertRaises(ValueError, lambda: prior.rescale(-1))
def test_sampling(self):
"""Test that sampling from the prior always returns values within its domain."""
for prior in self.priors:
single_sample = prior.sample()
self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum))
many_samples = prior.sample(1000)
self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
def test_prob(self):
"""Test that the prior probability is non-negative in domain of validity and zero outside."""
for prior in self.priors:
# skip delta function prior in this case
if isinstance(prior, tupak.prior.DeltaFunction):
continue
if prior.maximum != np.inf:
outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
if prior.minimum != -np.inf:
outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
self.assertTrue(all(prior.prob(outside_domain) == 0))
if prior.minimum == -np.inf:
prior.minimum = -1e5
if prior.maximum == np.inf:
prior.maximum = 1e5
domain = np.linspace(prior.minimum, prior.maximum, 1000)
self.assertTrue(all(prior.prob(domain) >= 0))
surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000)
prior.prob(surround_domain)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment