diff --git a/test/prior_tests.py b/test/prior_tests.py index 86a8a90d69d22f95150a996e535d958ca9e2d139..bbed21f609404280932b7038c49a0e014e2cc3a4 100644 --- a/test/prior_tests.py +++ b/test/prior_tests.py @@ -92,8 +92,9 @@ class TestPriorClasses(unittest.TestCase): tupak.prior.DeltaFunction(name='test', peak=1), tupak.prior.Gaussian(name='test', mu=0, sigma=1), tupak.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1), - tupak.prior.PowerLaw(name='test', alpha=-1, minimum=1, maximum=1e2), + tupak.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2), tupak.prior.Uniform(name='test', minimum=0, maximum=1), + tupak.prior.LogUniform(name='test', minimum=5e0, maximum=1e2), tupak.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3), tupak.prior.Sine(name='test'), tupak.prior.Cosine(name='test'), @@ -102,26 +103,42 @@ class TestPriorClasses(unittest.TestCase): tupak.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1) ] - def test_rescaling(self): + def test_minimum_rescaling(self): + """Test the the rescaling works as expected.""" for prior in self.priors: - """Test the the rescaling works as expected.""" minimum_sample = prior.rescale(0) self.assertAlmostEqual(minimum_sample, prior.minimum) + + def test_maximum_rescaling(self): + """Test the the rescaling works as expected.""" + for prior in self.priors: maximum_sample = prior.rescale(1) self.assertAlmostEqual(maximum_sample, prior.maximum) + + def test_many_sample_rescaling(self): + """Test the the rescaling works as expected.""" + for prior in self.priors: many_samples = prior.rescale(np.random.uniform(0, 1, 1000)) self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum))) + + def test_out_of_bounds_rescaling(self): + """Test the the rescaling works as expected.""" + for prior in self.priors: self.assertRaises(ValueError, lambda: prior.rescale(-1)) - def test_sampling(self): + def test_sampling_single(self): """Test that sampling from the prior always returns values within its domain.""" for prior in self.priors: single_sample = prior.sample() self.assertTrue((single_sample >= prior.minimum) & (single_sample <= prior.maximum)) + + def test_sampling_many(self): + """Test that sampling from the prior always returns values within its domain.""" + for prior in self.priors: many_samples = prior.sample(1000) self.assertTrue(all((many_samples >= prior.minimum) & (many_samples <= prior.maximum))) - def test_prob(self): + def test_probability_above_domain(self): """Test that the prior probability is non-negative in domain of validity and zero outside.""" for prior in self.priors: # skip delta function prior in this case @@ -130,15 +147,36 @@ class TestPriorClasses(unittest.TestCase): if prior.maximum != np.inf: outside_domain = np.linspace(prior.maximum + 1, prior.maximum + 1e4, 1000) self.assertTrue(all(prior.prob(outside_domain) == 0)) + + def test_probability_below_domain(self): + """Test that the prior probability is non-negative in domain of validity and zero outside.""" + for prior in self.priors: + # skip delta function prior in this case + if isinstance(prior, tupak.prior.DeltaFunction): + continue if prior.minimum != -np.inf: outside_domain = np.linspace(prior.minimum - 1e4, prior.minimum - 1, 1000) self.assertTrue(all(prior.prob(outside_domain) == 0)) + + def test_probability_in_domain(self): + """Test that the prior probability is non-negative in domain of validity and zero outside.""" + for prior in self.priors: + # skip delta function prior in this case + if isinstance(prior, tupak.prior.DeltaFunction): + continue if prior.minimum == -np.inf: prior.minimum = -1e5 if prior.maximum == np.inf: prior.maximum = 1e5 domain = np.linspace(prior.minimum, prior.maximum, 1000) self.assertTrue(all(prior.prob(domain) >= 0)) + + def test_probability_surrounding_domain(self): + """Test that the prior probability is non-negative in domain of validity and zero outside.""" + for prior in self.priors: + # skip delta function prior in this case + if isinstance(prior, tupak.prior.DeltaFunction): + continue surround_domain = np.linspace(prior.minimum - 1, prior.maximum + 1, 1000) prior.prob(surround_domain) @@ -151,6 +189,7 @@ class TestPriorClasses(unittest.TestCase): domain = np.linspace(-1e2, 1e2, 1000) else: domain = np.linspace(prior.minimum, prior.maximum, 1000) + print(prior.minimum, prior.maximum) self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3) @@ -189,6 +228,5 @@ class TestFillPrior(unittest.TestCase): self.assertIsInstance(self.priors['ra'], tupak.prior.Uniform) - if __name__ == '__main__': unittest.main()