From 664eeef058b3599c9db16d569d6b27b2a9cbae3a Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Mon, 14 May 2018 12:09:04 +1000 Subject: [PATCH] add testing of specific prior distributions --- test/prior_tests.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) diff --git a/test/prior_tests.py b/test/prior_tests.py index cc105c4d8..3bad18bbc 100644 --- a/test/prior_tests.py +++ b/test/prior_tests.py @@ -111,5 +111,64 @@ class TestFixMethod(unittest.TestCase): self.assertRaises(ValueError, tupak.prior.fix, self.prior, np.nan) +class TestPriorClasses(unittest.TestCase): + + def setUp(self): + + self.priors = [ + tupak.prior.DeltaFunction(name='test', peak=1), + tupak.prior.Gaussian(name='test', mu=0, sigma=1), + tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1), + tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2), + tupak.prior.Uniform(name='test', minimum=0, maximum=1), + tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3), + tupak.prior.Sine(name='test'), + tupak.prior.Cosine(name='test'), + tupak.prior.Interped(name='test', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000)**4, + minimum=3, maximum=5), + tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1) + ] + + def test_rescaling(self): + for prior in self.priors: + """Test the the rescaling works as expected.""" + minimum_sample = prior.rescale(0) + self.assertAlmostEqual(minimum_sample, prior.minimum) + maximum_sample = prior.rescale(1) + self.assertAlmostEqual(maximum_sample, prior.maximum) + many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) + self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum))) + self.assertRaises(ValueError, lambda: prior.rescale(-1)) + + def test_sampling(self): + """Test that sampling from the prior always returns values within its domain.""" + for prior in self.priors: + single_sample = prior.sample() + self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum)) + many_samples = prior.sample(1000) + self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum))) + + def test_prob(self): + """Test that the prior probability is non-negative in domain of validity and zero outside.""" + for prior in self.priors: + # skip delta function prior in this case + if isinstance(prior, tupak.prior.DeltaFunction): + continue + if prior.maximum != np.inf: + outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000) + self.assertTrue(all(prior.prob(outside_domain) == 0)) + if prior.minimum != -np.inf: + outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) + self.assertTrue(all(prior.prob(outside_domain) == 0)) + if prior.minimum == -np.inf: + prior.minimum = -1e5 + if prior.maximum == np.inf: + prior.maximum = 1e5 + domain = np.linspace(prior.minimum, prior.maximum, 1000) + self.assertTrue(all(prior.prob(domain) >= 0)) + surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) + prior.prob(surround_domain) + + if __name__ == '__main__': unittest.main() -- GitLab