From ae4fc1d879bd77724ba7479b4a220cccc6ccc82c Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 14 Jan 2022 14:31:27 +0000
Subject: [PATCH] Wrap astropy z_at_value for v5.

---
 bilby/gw/conversion.py |  4 +---
 bilby/gw/cosmology.py  | 12 ++++++++++++
 bilby/gw/prior.py      |  3 +--
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/bilby/gw/conversion.py b/bilby/gw/conversion.py
index 492a6201f..9c516c737 100644
--- a/bilby/gw/conversion.py
+++ b/bilby/gw/conversion.py
@@ -11,7 +11,7 @@ from ..core.utils import logger, solar_mass, command_line_args
 from ..core.prior import DeltaFunction
 from .utils import lalsim_SimInspiralTransformPrecessingNewInitialConditions
 from .eos.eos import SpectralDecompositionEOS, EOSFamily, IntegrateTOV
-from .cosmology import get_cosmology
+from .cosmology import get_cosmology, z_at_value
 
 
 def redshift_to_luminosity_distance(redshift, cosmology=None):
@@ -27,7 +27,6 @@ def redshift_to_comoving_distance(redshift, cosmology=None):
 @np.vectorize
 def luminosity_distance_to_redshift(distance, cosmology=None):
     from astropy import units
-    from astropy.cosmology import z_at_value
     cosmology = get_cosmology(cosmology)
     return z_at_value(cosmology.luminosity_distance, distance * units.Mpc)
 
@@ -35,7 +34,6 @@ def luminosity_distance_to_redshift(distance, cosmology=None):
 @np.vectorize
 def comoving_distance_to_redshift(distance, cosmology=None):
     from astropy import units
-    from astropy.cosmology import z_at_value
     cosmology = get_cosmology(cosmology)
     return z_at_value(cosmology.comoving_distance, distance * units.Mpc)
 
diff --git a/bilby/gw/cosmology.py b/bilby/gw/cosmology.py
index ded3f5eb6..c05c7ada3 100644
--- a/bilby/gw/cosmology.py
+++ b/bilby/gw/cosmology.py
@@ -63,3 +63,15 @@ def set_cosmology(cosmology=None):
         COSMOLOGY[1] = cosmology.name
     else:
         COSMOLOGY[1] = repr(cosmology)
+
+
+def z_at_value(func, fval, **kwargs):
+    """
+    Wrapped version of :code:`astropy.cosmology.z_at_value` to return float
+    rather than an :code:`astropy Quantity` as returned for :code:`astropy>=5`.
+
+    See https://docs.astropy.org/en/stable/api/astropy.cosmology.z_at_value.html#astropy.cosmology.z_at_value
+    for detailed documentation.
+    """
+    from astropy.cosmology import z_at_value
+    return float(z_at_value(func=func, fval=fval, **kwargs))
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 7f63d91ee..547759740 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -20,7 +20,7 @@ from .conversion import (
     generate_all_bbh_parameters,
     chirp_mass_and_mass_ratio_to_total_mass,
     total_mass_and_mass_ratio_to_component_masses)
-from .cosmology import get_cosmology
+from .cosmology import get_cosmology, z_at_value
 from .source import PARAMETER_SETS
 from .utils import calculate_time_to_merger
 
@@ -170,7 +170,6 @@ class Cosmological(Interped):
         recalculate_array: boolean
             Determines if the distance arrays are recalculated
         """
-        from astropy.cosmology import z_at_value
         cosmology = get_cosmology(self.cosmology)
         limit_dict[self.name] = value
         if self.name == 'redshift':
-- 
GitLab