From b1a890e4bfdc5e7bdf8504cf1b2e413c96efb711 Mon Sep 17 00:00:00 2001
From: Bruce Edelman <bruce.edelman@ligo.org>
Date: Mon, 13 Jan 2020 15:47:00 -0600
Subject: [PATCH] Generalise comp mass conv (Resolve #436)

---
 bilby/gw/prior.py     | 45 +++++++++++++++++++++----------------------
 test/gw_prior_test.py |  1 +
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/bilby/gw/prior.py b/bilby/gw/prior.py
index 2154cf1ad..12f69717f 100644
--- a/bilby/gw/prior.py
+++ b/bilby/gw/prior.py
@@ -27,7 +27,9 @@ class BilbyPriorConversionError(Exception):
 
 
 def convert_to_flat_in_component_mass_prior(result, fraction=0.25):
-    """ Converts samples flat in chirp-mass and mass-ratio to flat in component mass
+    """ Converts samples with a defined prior in chirp-mass and mass-ratio to flat in component mass by resampling with
+    the posterior with weights defined as ratio in new:old prior values times the jacobian which for
+    F(mc, q) -> G(m1, m2) is defined as J := m1^2 / mc
 
     Parameters
     ----------
@@ -40,35 +42,32 @@ def convert_to_flat_in_component_mass_prior(result, fraction=0.25):
 
     """
     if getattr(result, "priors") is not None:
-        if isinstance(getattr(result.priors, "chirp_mass", None), Uniform) is False:
-            BilbyPriorConversionError("chirp mass prior should be Uniform")
-        if isinstance(getattr(result.priors, "mass_ratio", None), Uniform) is False:
-            BilbyPriorConversionError("mass_ratio prior should be Uniform")
-        if isinstance(getattr(result.priors, "mass_1", None), Constraint):
-            BilbyPriorConversionError("mass_1 prior should be a Constraint")
-        if isinstance(getattr(result.priors, "mass_2", None), Constraint):
-            BilbyPriorConversionError("mass_2 prior should be a Constraint")
+        for key in ['chirp_mass', 'mass_ratio']:
+            if key not in result.priors.keys():
+                BilbyPriorConversionError("{} Prior not found in result object".format(key))
+            if isinstance(result.priors[key], Constraint):
+                BilbyPriorConversionError("{} Prior should not be a Constraint".format(key))
+        for key in ['mass_1', 'mass_2']:
+            if not isinstance(result.priors[key], Constraint):
+                BilbyPriorConversionError("{} Prior should be a Constraint Prior".format(key))
     else:
         BilbyPriorConversionError("No prior in the result: unable to convert")
 
     result = copy.copy(result)
     priors = result.priors
+    old_priors = copy.copy(result.priors)
     posterior = result.posterior
 
-    priors["chirp_mass"] = Constraint(
-        priors["chirp_mass"].minimum, priors["chirp_mass"].maximum,
-        "chirp_mass", latex_label=priors["chirp_mass"].latex_label)
-    priors["mass_ratio"] = Constraint(
-        priors["mass_ratio"].minimum, priors["mass_ratio"].maximum,
-        "mass_ratio", latex_label=priors["mass_ratio"].latex_label)
-    priors["mass_1"] = Uniform(
-        priors["mass_1"].minimum, priors["mass_1"].maximum, "mass_1",
-        latex_label=priors["mass_1"].latex_label, unit="$M_{\odot}$")
-    priors["mass_2"] = Uniform(
-        priors["mass_2"].minimum, priors["mass_2"].maximum, "mass_2",
-        latex_label=priors["mass_2"].latex_label, unit="$M_{\odot}$")
-
-    weights = posterior["mass_1"] ** 2 / posterior["chirp_mass"]
+    for key in ['chirp_mass', 'mass_ratio']:
+        priors[key] = Constraint(priors[key].minimum, priors[key].maximum, key, latex_label=priors[key].latex_label)
+    for key in ['mass_1', 'mass_2']:
+        priors[key] = Uniform(priors[key].minimum, priors[key].maximum, key, latex_label=priors[key].latex_label,
+                              unit="$M_{\odot}$")
+
+    weights = np.array(result.get_weights_by_new_prior(old_priors, priors,
+                                                       prior_names=['chirp_mass', 'mass_ratio', 'mass_1', 'mass_2']))
+    jacobian = posterior["mass_1"] ** 2 / posterior["chirp_mass"]
+    weights = jacobian * weights
     result.posterior = posterior.sample(frac=fraction, weights=weights)
 
     logger.info("Resampling posterior to flat-in-component mass")
diff --git a/test/gw_prior_test.py b/test/gw_prior_test.py
index f564ead63..28bbaf1da 100644
--- a/test/gw_prior_test.py
+++ b/test/gw_prior_test.py
@@ -143,6 +143,7 @@ class TestPriorConversion(unittest.TestCase):
             mass_1=Constraint(name='mass_1', minimum=mass_2[0], maximum=mass_2[1])))
 
         lalinf_prior = BBHPriorDict(dictionary=dict(
+            mass_ratio=Constraint(name='mass_ratio', minimum=mass_ratio[0], maximum=mass_ratio[1]),
             chirp_mass=Constraint(name='chirp_mass', minimum=chirp_mass[0], maximum=chirp_mass[1]),
             mass_2=Uniform(name='mass_2', minimum=mass_1[0], maximum=mass_1[1]),
             mass_1=Uniform(name='mass_1', minimum=mass_2[0], maximum=mass_2[1])))
-- 
GitLab