From 266b5e2e813e3f5bf0c70792dddc30d87050983a Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Mon, 28 Jan 2019 12:43:42 +1100
Subject: [PATCH] introduce Cosmological prior which UCV subclasses

---
 bilby/gw/prior.py | 205 +++++++++++++++++++++++++++++-----------------
 1 file changed, 130 insertions(+), 75 deletions(-)

diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index dea62c2f..1b51fef5 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -5,7 +5,7 @@ from scipy.interpolate import UnivariateSpline
 
 from ..core.prior import (PriorDict, Uniform, Prior, DeltaFunction, Gaussian,
                           Interped)
-from ..core.utils import logger
+from ..core.utils import infer_args_from_method, logger
 from .cosmology import get_cosmology
 
 try:
@@ -15,85 +15,140 @@ except ImportError:
                  " not be able to use some of the prebuilt functions.")
 
 
-class UniformComovingVolume(Interped):
-
-    def __init__(self, minimum=None, maximum=None, cosmology=None,
-                 name='luminosity distance', unit='Mpc'):
-        """
-        Prior distribution on either luminosity distance or redshift which is
-        uniform in comoving volume based on the specified cosmology.
-
-        The default cosmology is Planck15 as implemented in astropy.
-
-        Parameters
-        ----------
-        minimum: float, optional
-            See superclass
-        maximum: float, optional
-            See superclass
-        cosmology: (astropy.cosmology.FlatLambdaCDM, str), optional
-            Cosmology to use. If a string that string will be searched for in
-            astropy.cosmology. Default is project cosmology
-        name: str, optional
-            Name specifying which distance parameter to use, the options are:
-                - luminosity_distance
-                - redshift
-        unit: (astropy.units.Mpc, str), optional
-            Units, if a string that string will be searched for in
-            astropy.units. Default=Mpc
-        """
-        cosmology = get_cosmology(cosmology)
-        if isinstance(unit, str):
-            unit = units.__dict__[unit]
-        if cosmology.name is not None:
-            self.cosmology = cosmology.name
-        else:
-            self.cosmology = cosmology
+class Cosmological(Interped):
+
+    _default_args_dict = dict(
+        redshift=dict(name='redshift', latex_label='$z$', unit=None),
+        luminosity_distance=dict(
+            name='luminosity_distance', latex_label='$d_L$', unit='Mpc'),
+        comoving_distance=dict(
+            name='comoving_distance', latex_label='$d_L$', unit='Mpc'))
+
+    def __init__(self, minimum, maximum, cosmology=None, name=None,
+                 latex_label=None, unit=None):
+        self.cosmology = get_cosmology(cosmology)
+        self.name = name
+        label_args = self._default_args_dict[self.name]
+        if latex_label is not None:
+            label_args['latex_label'] = latex_label
+        if unit is not None:
+            if isinstance(unit, str):
+                unit = units.__dict__[unit]
+            label_args['unit'] = unit
+        self.unit = unit
+        self._minimum = dict()
+        self._maximum = dict()
+        self.minimum = minimum
+        self.maximum = maximum
         if name == 'redshift':
-            zs = np.linspace(minimum, maximum, 1000)
-            dvc_dz = cosmology.differential_comoving_volume(zs).value
-            Interped.__init__(
-                self, xx=zs, yy=dvc_dz, minimum=minimum, maximum=maximum,
-                name=name, latex_label='$z$')
+            xx, yy = self._get_redshift_arrays()
+        elif name == 'comoving_distance':
+            xx, yy = self._get_comoving_distance_arrays()
+        elif name == 'luminosity_distance':
+            xx, yy = self._get_luminosity_distance_arrays()
         else:
+            raise ValueError('Name {} not recognized.'.format(name))
+        Interped.__init__(self, xx=xx, yy=yy, minimum=minimum, maximum=maximum,
+                          **label_args)
+
+    @property
+    def minimum(self):
+        return self._minimum[self.name]
+
+    @minimum.setter
+    def minimum(self, minimum):
+        cosmology = get_cosmology(self.cosmology)
+        self._minimum[self.name] = minimum
+        if self.name == 'redshift':
+            self._minimum['luminosity_distance'] =\
+                cosmology.luminosity_distance(minimum).value
+            self._minimum['comoving_distance'] =\
+                cosmology.comoving_distance(minimum).value
+        elif self.name == 'luminosity_distance':
             if minimum == 0:
-                z_min = 0
+                self._minimum['redshift'] = 0
             else:
-                z_min = cosmo.z_at_value(
-                    cosmology.luminosity_distance, minimum * unit)
-            z_max = cosmo.z_at_value(
-                cosmology.luminosity_distance, maximum * unit)
-            zs = np.linspace(z_min * 0.99, z_max * 1.01, 1000)
-            dvc_dz = cosmology.differential_comoving_volume(zs).value
-            dl_of_z = np.array([cosmology.luminosity_distance(z).value
-                                for z in zs])
-            ddl_dz = np.gradient(dl_of_z, zs)
-            dvc_ddl = dvc_dz / ddl_dz
-            Interped.__init__(
-                self, xx=dl_of_z, yy=dvc_ddl, minimum=minimum,
-                maximum=maximum, name=name, latex_label='$d_L$', unit=unit)
-
-    def get_redshift_prior(self):
+                self._minimum['redshift'] = cosmo.z_at_value(
+                    cosmology.luminosity_distance, minimum * self.unit)
+            self._minimum['comoving_distance'] = self._minimum['redshift']
+        elif self.name == 'comoving_distance':
+            if minimum == 0:
+                self._minimum['redshift'] = 0
+            else:
+                self._minimum['redshift'] = cosmo.z_at_value(
+                    cosmology.comoving_distance, minimum * self.unit)
+            self._minimum['luminosity_distance'] = self._minimum['redshift']
+        if getattr(self._maximum, self.name, np.inf) < np.inf:
+            self.__update_instance()
+
+    @property
+    def maximum(self):
+        return self._maximum[self.name]
+
+    @maximum.setter
+    def maximum(self, maximum):
+        cosmology = get_cosmology(self.cosmology)
+        self._maximum[self.name] = maximum
         if self.name == 'redshift':
-            return self
-        else:
-            cosmology = get_cosmology(self.cosmology)
-            return UniformComovingVolume(
-                minimum=cosmo.z_at_value(
-                    cosmology.luminosity_distance, self.minimum * self.unit),
-                maximum=cosmo.z_at_value(
-                    cosmology.luminosity_distance, self.minimum * self.unit),
-                name='redshift', cosmology=cosmology)
-
-    def get_luminosity_distance_prior(self):
-        if self.name == 'luminosity_distance':
-            return self
-        else:
-            cosmology = get_cosmology(self.cosmology)
-            return UniformComovingVolume(
-                minimum=cosmology.luminosity_distance(self.minimum).value,
-                maximum=cosmology.luminosity_distance(self.maximum).value,
-                name='luminosity_distance', cosmology=cosmology)
+            self._maximum['luminosity_distance'] = \
+                cosmology.luminosity_distance(maximum).value
+            self._maximum['comoving_distance'] = \
+                cosmology.comoving_distance(maximum).value
+        elif self.name == 'luminosity_distance':
+            self._maximum['redshift'] = cosmo.z_at_value(
+                cosmology.luminosity_distance, maximum * self.unit)
+            self._maximum['comoving_distance'] = self._maximum['redshift']
+        elif self.name == 'comoving_distance':
+            self._maximum['redshift'] = cosmo.z_at_value(
+                cosmology.comoving_distance, maximum * self.unit)
+            self._maximum['luminosity_distance'] = self._maximum['redshift']
+        if getattr(self._minimum, self.name, np.inf) < np.inf:
+            self.__update_instance()
+
+    def get_corresponding_prior(self, name=None, unit=None):
+        subclass_args = infer_args_from_method(self.__init__)
+        args_dict = {key: getattr(self, key) for key in subclass_args}
+        self._convert_to(new=name, args_dict=args_dict)
+        if unit is not None:
+            args_dict['unit'] = unit
+        return self.__class__(**args_dict)
+
+    def _convert_to(self, new, args_dict):
+        args_dict.update(self._default_args_dict[new])
+        args_dict['minimum'] = self._minimum[args_dict['name']]
+        args_dict['maximum'] = self._maximum[args_dict['name']]
+
+    def _get_comoving_distance_arrays(self):
+        zs, p_dz = self._get_redshift_arrays()
+        dc_of_z = self.cosmology.comoving_distance(zs).value
+        ddc_dz = np.gradient(dc_of_z, zs)
+        p_dc = p_dz / ddc_dz
+        return dc_of_z, p_dc
+
+    def _get_luminosity_distance_arrays(self):
+        zs, p_dz = self._get_redshift_arrays()
+        dl_of_z = self.cosmology.luminosity_distance(zs).value
+        ddl_dz = np.gradient(dl_of_z, zs)
+        p_dl = p_dz / ddl_dz
+        return dl_of_z, p_dl
+
+    def _get_redshift_arrays(self):
+        raise NotImplementedError
+
+
+class UniformComovingVolume(Cosmological):
+
+    def __init__(self, minimum, maximum, cosmology=None,
+                 name='luminosity_distance', latex_label='$d_L$', unit='Mpc'):
+        Cosmological.__init__(
+            self, minimum=minimum, maximum=maximum, cosmology=cosmology,
+            name=name, latex_label=latex_label, unit=unit)
+
+    def _get_redshift_arrays(self):
+        zs = np.linspace(self._minimum['redshift'] * 0.99,
+                         self._maximum['redshift'] * 1.01, 1000)
+        p_dz = self.cosmology.differential_comoving_volume(zs).value
+        return zs, p_dz
 
 
 class AlignedSpin(Interped):
-- 
GitLab