From 485708b05c27217c9d8ab861355a6af63b71ead6 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Wed, 14 Jun 2023 21:26:14 +0000
Subject: [PATCH] Merge branch 'read-write-cosmological-prior' into 'master'

BUGFIX: make sure cosmological priors can be written and read

See merge request lscsoft/bilby!1258

(cherry picked from commit 5231e3b576646b911f19d70971dfda231019ed4a)

2502c9bd BUGFIX: make sure cosmological priors can be written and read
---
 bilby/core/prior/base.py | 12 ------------
 bilby/gw/prior.py        | 21 ++++++++-------------
 test/gw/prior_test.py    |  8 ++++++++
 3 files changed, 16 insertions(+), 25 deletions(-)

diff --git a/bilby/core/prior/base.py b/bilby/core/prior/base.py
index 1917fb3d0..eef710a60 100644
--- a/bilby/core/prior/base.py
+++ b/bilby/core/prior/base.py
@@ -222,18 +222,6 @@ class Prior(object):
         else:
             return f"{prior_module}.{prior_name}({args})"
 
-    @property
-    def _repr_dict(self):
-        """
-        Get a dictionary containing the arguments needed to reproduce this object.
-        """
-        property_names = {p for p in dir(self.__class__) if isinstance(getattr(self.__class__, p), property)}
-        subclass_args = infer_args_from_method(self.__init__)
-        dict_with_properties = self.__dict__.copy()
-        for key in property_names.intersection(subclass_args):
-            dict_with_properties[key] = getattr(self, key)
-        return {key: dict_with_properties[key] for key in subclass_args}
-
     @property
     def is_fixed(self):
         """
diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 455296e55..f8b919c45 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -292,20 +292,15 @@ class Cosmological(Interped):
         else:
             return cls._from_repr(string)
 
-    @property
-    def _repr_dict(self):
-        """
-        Get a dictionary containing the arguments needed to reproduce this object.
-        """
-        from astropy.cosmology.core import Cosmology
+    def get_instantiation_dict(self):
         from astropy import units
-        dict_with_properties = super(Cosmological, self)._repr_dict
-        if isinstance(dict_with_properties['cosmology'], Cosmology):
-            if dict_with_properties['cosmology'].name is not None:
-                dict_with_properties['cosmology'] = dict_with_properties['cosmology'].name
-        if isinstance(dict_with_properties['unit'], units.Unit):
-            dict_with_properties['unit'] = dict_with_properties['unit'].to_string()
-        return dict_with_properties
+        from astropy.cosmology.realizations import available
+        instantiation_dict = super().get_instantiation_dict()
+        if self.cosmology.name in available:
+            instantiation_dict['cosmology'] = self.cosmology.name
+        if isinstance(self.unit, units.Unit):
+            instantiation_dict['unit'] = self.unit.to_string()
+        return instantiation_dict
 
 
 class UniformComovingVolume(Cosmological):
diff --git a/test/gw/prior_test.py b/test/gw/prior_test.py
index 832f2a9d5..5a5d3b3ff 100644
--- a/test/gw/prior_test.py
+++ b/test/gw/prior_test.py
@@ -37,6 +37,14 @@ class TestBBHPriorDict(unittest.TestCase):
         del self.bbh_prior_dict
         del self.base_directory
 
+    def test_read_write_default_prior(self):
+        filename = "test_prior.prior"
+        self.bbh_prior_dict.to_file(outdir=".", label="test_prior")
+        new_prior = bilby.gw.prior.BBHPriorDict(filename=filename)
+        for key in self.bbh_prior_dict:
+            self.assertEqual(self.bbh_prior_dict[key], new_prior[key])
+        os.remove(filename)
+
     def test_create_default_prior(self):
         default = bilby.gw.prior.BBHPriorDict()
         minima = all(
-- 
GitLab