diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py index 1b51fef564792a208b03a9916bdd11d1607c1dd0..375124147a1dbbbb36e972dc0f0cda15adb969c1 100644 --- a/bilby/gw/prior.py +++ b/bilby/gw/prior.py @@ -20,13 +20,17 @@ class Cosmological(Interped): _default_args_dict = dict( redshift=dict(name='redshift', latex_label='$z$', unit=None), luminosity_distance=dict( - name='luminosity_distance', latex_label='$d_L$', unit='Mpc'), + name='luminosity_distance', latex_label='$d_L$', unit=units.Mpc), comoving_distance=dict( - name='comoving_distance', latex_label='$d_L$', unit='Mpc')) + name='comoving_distance', latex_label='$d_C$', unit=units.Mpc)) def __init__(self, minimum, maximum, cosmology=None, name=None, latex_label=None, unit=None): self.cosmology = get_cosmology(cosmology) + if name not in self._default_args_dict: + raise ValueError( + "Name {} not recognised. Must be one of luminosity_distance, " + "comoving_distance, redshift".format(name)) self.name = name label_args = self._default_args_dict[self.name] if latex_label is not None: @@ -35,7 +39,7 @@ class Cosmological(Interped): if isinstance(unit, str): unit = units.__dict__[unit] label_args['unit'] = unit - self.unit = unit + self.unit = label_args['unit'] self._minimum = dict() self._maximum = dict() self.minimum = minimum @@ -138,12 +142,6 @@ class Cosmological(Interped): class UniformComovingVolume(Cosmological): - def __init__(self, minimum, maximum, cosmology=None, - name='luminosity_distance', latex_label='$d_L$', unit='Mpc'): - Cosmological.__init__( - self, minimum=minimum, maximum=maximum, cosmology=cosmology, - name=name, latex_label=latex_label, unit=unit) - def _get_redshift_arrays(self): zs = np.linspace(self._minimum['redshift'] * 0.99, self._maximum['redshift'] * 1.01, 1000) diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py index a3bf4e6891755677d000380f3286e01002a8e0b2..f7c7454e9472f496b048045e1719b23f08321cfa 100644 --- a/test/gw_prior_test.py +++ b/test/gw_prior_test.py @@ -79,31 +79,34 @@ class TestUniformComovingVolumePrior(unittest.TestCase): pass def test_minimum(self): - prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000) + prior = bilby.gw.prior.UniformComovingVolume( + minimum=10, maximum=10000, name='luminosity_distance') self.assertEqual(prior.minimum, 10) def test_maximum(self): - prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000) + prior = bilby.gw.prior.UniformComovingVolume( + minimum=10, maximum=10000, name='luminosity_distance') self.assertEqual(prior.maximum, 10000) def test_zero_minimum_works(self): - prior = bilby.gw.prior.UniformComovingVolume(minimum=0, maximum=10000) + prior = bilby.gw.prior.UniformComovingVolume( + minimum=0, maximum=10000, name='luminosity_distance') self.assertEqual(prior.minimum, 0) def test_specify_cosmology(self): prior = bilby.gw.prior.UniformComovingVolume( - minimum=10, maximum=10000, cosmology='Planck13') + minimum=10, maximum=10000, name='luminosity_distance', + cosmology='Planck13') self.assertEqual(repr(prior.cosmology), repr(cosmology.Planck13)) def test_comoving_prior_creation(self): prior = bilby.gw.prior.UniformComovingVolume( - minimum=0.1, maximum=1, name='comoving_distance', - latex_label='$d_C$') + minimum=10, maximum=1000, name='comoving_distance') self.assertEqual(prior.latex_label, '$d_C$') def test_redshift_prior_creation(self): prior = bilby.gw.prior.UniformComovingVolume( - minimum=0.1, maximum=1, name='redshift', latex_label='$z$') + minimum=0.1, maximum=1, name='redshift') self.assertEqual(prior.latex_label, '$z$') def test_redshift_to_luminosity_distance(self): @@ -113,12 +116,14 @@ class TestUniformComovingVolumePrior(unittest.TestCase): self.assertEqual(new_prior.name, 'luminosity_distance') def test_luminosity_distance_to_redshift(self): - prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000) + prior = bilby.gw.prior.UniformComovingVolume( + minimum=10, maximum=10000, name='luminosity_distance') new_prior = prior.get_corresponding_prior('redshift') self.assertEqual(new_prior.name, 'redshift') def test_luminosity_distance_to_comoving_distance(self): - prior = bilby.gw.prior.UniformComovingVolume(minimum=10, maximum=10000) + prior = bilby.gw.prior.UniformComovingVolume( + minimum=10, maximum=10000, name='luminosity_distance') new_prior = prior.get_corresponding_prior('comoving_distance') self.assertEqual(new_prior.name, 'comoving_distance') diff --git a/test/prior_test.py b/test/prior_test.py index aeb6a24b856685e1360d74a1460c95d3e47063c5..d8d4041de3a719e4a4238945f60e522b3a641388 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -133,7 +133,7 @@ class TestPriorClasses(unittest.TestCase): bilby.core.prior.PowerLaw(name='test', unit='unit', alpha=2, minimum=1, maximum=1e2), bilby.core.prior.Uniform(name='test', unit='unit', minimum=0, maximum=1), bilby.core.prior.LogUniform(name='test', unit='unit', minimum=5e0, maximum=1e2), - bilby.gw.prior.UniformComovingVolume(name='redshift', unit=None, minimum=0.1, maximum=1.0), + bilby.gw.prior.UniformComovingVolume(name='redshift', minimum=0.1, maximum=1.0), bilby.core.prior.Sine(name='test', unit='unit'), bilby.core.prior.Cosine(name='test', unit='unit'), bilby.core.prior.Interped(name='test', unit='unit', xx=np.linspace(0, 10, 1000),