From 664eeef058b3599c9db16d569d6b27b2a9cbae3a Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 14 May 2018 12:09:04 +1000
Subject: [PATCH] add testing of specific prior distributions

---
 test/prior_tests.py | 59 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/test/prior_tests.py b/test/prior_tests.py
index cc105c4d8..3bad18bbc 100644
--- a/test/prior_tests.py
+++ b/test/prior_tests.py
@@ -111,5 +111,64 @@ class TestFixMethod(unittest.TestCase):
         self.assertRaises(ValueError, tupak.prior.fix, self.prior, np.nan)
 
 
+class TestPriorClasses(unittest.TestCase):
+
+    def setUp(self):
+
+        self.priors = [
+            tupak.prior.DeltaFunction(name='test', peak=1),
+            tupak.prior.Gaussian(name='test', mu=0, sigma=1),
+            tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
+            tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2),
+            tupak.prior.Uniform(name='test', minimum=0, maximum=1),
+            tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
+            tupak.prior.Sine(name='test'),
+            tupak.prior.Cosine(name='test'),
+            tupak.prior.Interped(name='test', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000)**4,
+                                 minimum=3, maximum=5),
+            tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1)
+        ]
+
+    def test_rescaling(self):
+        for prior in self.priors:
+            """Test the the rescaling works as expected."""
+            minimum_sample = prior.rescale(0)
+            self.assertAlmostEqual(minimum_sample, prior.minimum)
+            maximum_sample = prior.rescale(1)
+            self.assertAlmostEqual(maximum_sample, prior.maximum)
+            many_samples = prior.rescale(np.random.uniform(0, 1, 1000))
+            self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
+            self.assertRaises(ValueError, lambda: prior.rescale(-1))
+
+    def test_sampling(self):
+        """Test that sampling from the prior always returns values within its domain."""
+        for prior in self.priors:
+            single_sample = prior.sample()
+            self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum))
+            many_samples = prior.sample(1000)
+            self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum)))
+
+    def test_prob(self):
+        """Test that the prior probability is non-negative in domain of validity and zero outside."""
+        for prior in self.priors:
+            # skip delta function prior in this case
+            if isinstance(prior, tupak.prior.DeltaFunction):
+                continue
+            if prior.maximum != np.inf:
+                outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000)
+                self.assertTrue(all(prior.prob(outside_domain) == 0))
+            if prior.minimum != -np.inf:
+                outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000)
+                self.assertTrue(all(prior.prob(outside_domain) == 0))
+            if prior.minimum == -np.inf:
+                prior.minimum = -1e5
+            if prior.maximum == np.inf:
+                prior.maximum = 1e5
+            domain = np.linspace(prior.minimum, prior.maximum, 1000)
+            self.assertTrue(all(prior.prob(domain) >= 0))
+            surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000)
+            prior.prob(surround_domain)
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab