diff --git a/test/prior_tests.py b/test/prior_tests.py
index 8e1950cbabfeb52f9e2011867a7bfd0a1e602125..8f25a0292d89c7ecaf024c8ea075ca93744e285c 100644
--- a/test/prior_tests.py
+++ b/test/prior_tests.py
@@ -103,32 +103,32 @@ class TestPriorClasses(unittest.TestCase):
     def setUp(self):
 
         self.priors = [
-            tupak.core.prior.DeltaFunction(name='test', peak=1),
-            tupak.core.prior.Gaussian(name='test', mu=0, sigma=1),
-            tupak.core.prior.Normal(name='test', mu=0, sigma=1),
-            tupak.core.prior.PowerLaw(name='test', alpha=0, minimum=0, maximum=1),
-            tupak.core.prior.PowerLaw(name='test', alpha=2, minimum=1, maximum=1e2),
-            tupak.core.prior.Uniform(name='test', minimum=0, maximum=1),
-            tupak.core.prior.LogUniform(name='test', minimum=5e0, maximum=1e2),
+            tupak.core.prior.DeltaFunction(name='test', unit='unit', peak=1),
+            tupak.core.prior.Gaussian(name='test', unit='unit', mu=0, sigma=1),
+            tupak.core.prior.Normal(name='test', unit='unit', mu=0, sigma=1),
+            tupak.core.prior.PowerLaw(name='test', unit='unit', alpha=0, minimum=0, maximum=1),
+            tupak.core.prior.PowerLaw(name='test', unit='unit', alpha=2, minimum=1, maximum=1e2),
+            tupak.core.prior.Uniform(name='test', unit='unit', minimum=0, maximum=1),
+            tupak.core.prior.LogUniform(name='test', unit='unit', minimum=5e0, maximum=1e2),
             tupak.gw.prior.UniformComovingVolume(name='test', minimum=2e2, maximum=5e3),
-            tupak.core.prior.Sine(name='test'),
-            tupak.core.prior.Cosine(name='test'),
-            tupak.core.prior.Interped(name='test', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000) ** 4,
+            tupak.core.prior.Sine(name='test', unit='unit'),
+            tupak.core.prior.Cosine(name='test', unit='unit'),
+            tupak.core.prior.Interped(name='test', unit='unit', xx=np.linspace(0, 10, 1000), yy=np.linspace(0, 10, 1000) ** 4,
                                       minimum=3, maximum=5),
-            tupak.core.prior.TruncatedGaussian(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1),
-            tupak.core.prior.TruncatedNormal(name='test', mu=1, sigma=0.4, minimum=-1, maximum=1),
-            tupak.core.prior.HalfGaussian(name='test', sigma=1),
-            tupak.core.prior.HalfNormal(name='test', sigma=1),
-            tupak.core.prior.LogGaussian(name='test', mu=0, sigma=1),
-            tupak.core.prior.LogNormal(name='test', mu=0, sigma=1),
-            tupak.core.prior.Exponential(name='test', mu=1),
-            tupak.core.prior.StudentT(name='test', df=3, mu=0, scale=1),
-            tupak.core.prior.Beta(name='test', alpha=2.0, beta=2.0),
-            tupak.core.prior.Logistic(name='test', mu=0, scale=1),
-            tupak.core.prior.Cauchy(name='test', alpha=0, beta=1),
-            tupak.core.prior.Lorentzian(name='test', alpha=0, beta=1),
-            tupak.core.prior.Gamma(name='test', k=1, theta=1),
-            tupak.core.prior.ChiSquared(name='test', nu=2)
+            tupak.core.prior.TruncatedGaussian(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1),
+            tupak.core.prior.TruncatedNormal(name='test', unit='unit', mu=1, sigma=0.4, minimum=-1, maximum=1),
+            tupak.core.prior.HalfGaussian(name='test', unit='unit', sigma=1),
+            tupak.core.prior.HalfNormal(name='test', unit='unit', sigma=1),
+            tupak.core.prior.LogGaussian(name='test', unit='unit', mu=0, sigma=1),
+            tupak.core.prior.LogNormal(name='test', unit='unit', mu=0, sigma=1),
+            tupak.core.prior.Exponential(name='test', unit='unit', mu=1),
+            tupak.core.prior.StudentT(name='test', unit='unit', df=3, mu=0, scale=1),
+            tupak.core.prior.Beta(name='test', unit='unit', alpha=2.0, beta=2.0),
+            tupak.core.prior.Logistic(name='test', unit='unit', mu=0, scale=1),
+            tupak.core.prior.Cauchy(name='test', unit='unit', alpha=0, beta=1),
+            tupak.core.prior.Lorentzian(name='test', unit='unit', alpha=0, beta=1),
+            tupak.core.prior.Gamma(name='test', unit='unit', k=1, theta=1),
+            tupak.core.prior.ChiSquared(name='test', unit='unit', nu=2)
         ]
 
     def test_minimum_rescaling(self):
@@ -235,6 +235,13 @@ class TestPriorClasses(unittest.TestCase):
                 domain = np.linspace(prior.minimum, prior.maximum, 1000)
             self.assertAlmostEqual(np.trapz(prior.prob(domain), domain), 1, 3)
 
+    def test_unit_setting(self):
+        for prior in self.priors:
+            if isinstance(prior, tupak.gw.prior.UniformComovingVolume):
+                self.assertEqual('Mpc', prior.unit)
+            else:
+                self.assertEqual('unit', prior.unit)
+
 
 class TestFillPrior(unittest.TestCase):