diff --git a/test/prior_tests.py b/test/prior_tests.py index 8e1950cbabfeb52f9e2011867a7bfd0a1e602125..8f25a0292d89c7ecaf024c8ea075ca93744e285c 100644 --- a/test/prior_tests.py +++ b/test/prior_tests.py @@ -103,32 +103,32 @@ class TestPriorClasses(unittest.TestCase): def setUp(self): self.priors = [ - tupak.core.prior.DeltaFunction(name='test', peak=1), - tupak.core.prior.Gaussian(name='test', mu=0, sigma=1), - tupak.core.prior.Normal(name='test', mu=0, sigma=1), - tupak.core.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1), - tupak.core.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2), - tupak.core.prior.Uniform(name='test', minimum=0, maximum=1), - tupak.core.prior.LogUniform(name='test', minimum=5e0, maximum=1e2), + tupak.core.prior.DeltaFunction(name='test', unit='unit', peak=1), + tupak.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1), + tupak.core.prior.Normal(name='test', unit='unit', mu=0, sigma=1), + tupak.core.prior.PowerLaw(name='test', unit='unit', alpha=0, minimum=0, maximum=1), + tupak.core.prior.PowerLaw(name='test', unit='unit', alpha=2, minimum=1, maximum=1e2), + tupak.core.prior.Uniform(name='test', unit='unit', minimum=0, maximum=1), + tupak.core.prior.LogUniform(name='test', unit='unit', minimum=5e0, maximum=1e2), tupak.gw.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3), - tupak.core.prior.Sine(name='test'), - tupak.core.prior.Cosine(name='test'), - tupak.core.prior.Interped(name='test', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000) ** 4, + tupak.core.prior.Sine(name='test', unit='unit'), + tupak.core.prior.Cosine(name='test', unit='unit'), + tupak.core.prior.Interped(name='test', unit='unit', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000) ** 4, minimum=3, maximum=5), - tupak.core.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1), - tupak.core.prior.TruncatedNormal(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1), - tupak.core.prior.HalfGaussian(name='test', sigma=1), - tupak.core.prior.HalfNormal(name='test', sigma=1), - tupak.core.prior.LogGaussian(name='test', mu=0, sigma=1), - tupak.core.prior.LogNormal(name='test', mu=0, sigma=1), - tupak.core.prior.Exponential(name='test', mu=1), - tupak.core.prior.StudentT(name='test', df=3, mu=0, scale=1), - tupak.core.prior.Beta(name='test', alpha=2.0, beta=2.0), - tupak.core.prior.Logistic(name='test', mu=0, scale=1), - tupak.core.prior.Cauchy(name='test', alpha=0, beta=1), - tupak.core.prior.Lorentzian(name='test', alpha=0, beta=1), - tupak.core.prior.Gamma(name='test', k=1, theta=1), - tupak.core.prior.ChiSquared(name='test', nu=2) + tupak.core.prior.TruncatedGaussian(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1), + tupak.core.prior.TruncatedNormal(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1), + tupak.core.prior.HalfGaussian(name='test', unit='unit', sigma=1), + tupak.core.prior.HalfNormal(name='test', unit='unit', sigma=1), + tupak.core.prior.LogGaussian(name='test', unit='unit', mu=0, sigma=1), + tupak.core.prior.LogNormal(name='test', unit='unit', mu=0, sigma=1), + tupak.core.prior.Exponential(name='test', unit='unit', mu=1), + tupak.core.prior.StudentT(name='test', unit='unit', df=3, mu=0, scale=1), + tupak.core.prior.Beta(name='test', unit='unit', alpha=2.0, beta=2.0), + tupak.core.prior.Logistic(name='test', unit='unit', mu=0, scale=1), + tupak.core.prior.Cauchy(name='test', unit='unit', alpha=0, beta=1), + tupak.core.prior.Lorentzian(name='test', unit='unit', alpha=0, beta=1), + tupak.core.prior.Gamma(name='test', unit='unit', k=1, theta=1), + tupak.core.prior.ChiSquared(name='test', unit='unit', nu=2) ] def test_minimum_rescaling(self): @@ -235,6 +235,13 @@ class TestPriorClasses(unittest.TestCase): domain = np.linspace(prior.minimum, prior.maximum, 1000) self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3) + def test_unit_setting(self): + for prior in self.priors: + if isinstance(prior, tupak.gw.prior.UniformComovingVolume): + self.assertEqual('Mpc', prior.unit) + else: + self.assertEqual('unit', prior.unit) + class TestFillPrior(unittest.TestCase):