From c8717da3aab1141f15fb9c8c09a938df75e39915 Mon Sep 17 00:00:00 2001
From: Moritz <email@moritz-huebner.de>
Date: Wed, 2 Oct 2019 16:02:54 +0100
Subject: [PATCH] Improved naming and simplified some code
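
Rename the "conditional_variables" of a conditional prior to
"independent_variables", special-case "condition_func" instead of
"correlation_func" when building the instantiation dict, fold
ConditionalPriorMixin into conditional_prior_factory, and replace the
correlated_prior example with examples/other_examples/conditional_prior.py.

A minimal sketch of the renamed interface, mirroring the new example
(the parameter values are illustrative only):

    import bilby.core.prior

    def condition_func(reference_params, mass_1):
        # mass_1 is the independent variable; the returned dict gives the
        # conditioned parameters (here mu is simply passed through).
        return dict(mu=reference_params['mu'])

    mass_1 = bilby.core.prior.PowerLaw(alpha=-2.5, minimum=2, maximum=50, name='mass_1')
    mass_ratio = bilby.core.prior.ConditionalExponential(mu=2, name='mass_ratio',
                                                         condition_func=condition_func)
    priors = bilby.core.prior.ConditionalPriorDict(dictionary=dict(mass_1=mass_1,
                                                                   mass_ratio=mass_ratio))
    samples = priors.sample(1000)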

---
 bilby/core/prior.py                          | 189 +++++++++----------
 examples/other_examples/conditional_prior.py |  45 +++++
 examples/other_examples/correlated_prior.py  |  62 ------
 3 files changed, 137 insertions(+), 159 deletions(-)
 create mode 100644 examples/other_examples/conditional_prior.py
 delete mode 100644 examples/other_examples/correlated_prior.py

diff --git a/bilby/core/prior.py b/bilby/core/prior.py
index 9960126e4..ce1a1f234 100644
--- a/bilby/core/prior.py
+++ b/bilby/core/prior.py
@@ -541,13 +541,13 @@ class ConditionalPriorDict(PriorDict):
 
     def _get_conditional_variables(self, key):
         conditional_variables = dict()
-        for k in self[key].conditional_variables:
+        for k in self[key].independent_variables:
             conditional_variables[k] = self[k].least_recently_sampled
         return conditional_variables
 
     def _check_conditions_resolved(self, key, sampled_keys):
         conditions_resolved = True
-        for k in self[key].conditional_variables:
+        for k in self[key].independent_variables:
             if k not in sampled_keys:
                 conditions_resolved = False
         return conditions_resolved
@@ -570,7 +570,7 @@ class ConditionalPriorDict(PriorDict):
         ls = []
         for key in sample:
             if key in self.conditioned_keys:
-                conditional_variables = dict([(k, sample[k]) for k in self[key].conditional_variables])
+                conditional_variables = dict([(k, sample[k]) for k in self[key].independent_variables])
                 ls.append(self[key].prob(sample[key], **conditional_variables))
             else:
                 ls.append(self[key].prob(sample[key]))
@@ -592,7 +592,7 @@ class ConditionalPriorDict(PriorDict):
         ls = []
         for key in sample:
             if key in self.conditioned_keys:
-                conditional_variables = dict([(k, sample[k]) for k in self[key].conditional_variables])
+                conditional_variables = dict([(k, sample[k]) for k in self[key].independent_variables])
                 ls.append(self[key].ln_prob(sample[key], **conditional_variables))
             else:
                 ls.append(self[key].ln_prob(sample[key]))
@@ -619,7 +619,7 @@ class ConditionalPriorDict(PriorDict):
         for key, ind in zip(unconditional_keys, unconditional_idxs):
             rescale_dict[key] = self[key].rescale(theta[ind])
         for key, ind in zip(conditional_keys, conditional_idxs):
-            conditional_variables = dict([(k, rescale_dict[k]) for k in self[key].conditional_variables])
+            conditional_variables = dict([(k, rescale_dict[k]) for k in self[key].independent_variables])
             rescale_dict[key] = self[key].rescale(theta[ind], **conditional_variables)
         ls = [rescale_dict[key] for key in keys]
         return ls
@@ -924,7 +924,7 @@ class Prior(object):
             dict_with_properties[key] = getattr(self, key)
         instantiation_dict = OrderedDict()
         for key in subclass_args:
-            if key == 'correlation_func':
+            if key == 'condition_func':
                 instantiation_dict[key] = str(dict_with_properties[key])
             else:
                 instantiation_dict[key] = dict_with_properties[key]
@@ -3304,98 +3304,93 @@ class MultivariateNormal(MultivariateGaussian):
         prior distribution."""
 
 
-class ConditionalPriorMixin(object):
-
-    def sample(self, size=None, **conditional_variables):
-        """Draw a sample from the prior
-
-        Parameters
-        ----------
-        size: int or tuple of ints, optional
-            See numpy.random.uniform docs
-
-        Returns
-        -------
-        float: A random number between 0 and 1, rescaled to match the distribution of this Prior
-
-        """
-        self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size), **conditional_variables)
-        return self.least_recently_sampled
-
-    def rescale(self, val, **conditional_variables):
-        """
-        'Rescale' a sample from the unit line element to the appropriate Gaussian prior.
-
-        This maps to the inverse CDF. This has been analytically solved for this case.
-        """
-        self.update_conditions(**conditional_variables)
-        return super(ConditionalPriorMixin, self).rescale(val)
-
-    def prob(self, val, **conditional_variables):
-        """Return the prior probability of val.
-
-        Parameters
-        ----------
-        val: Union[float, int, array_like]
-
-        Returns
-        -------
-        float: Prior probability of val
-        """
-        self.update_conditions(**conditional_variables)
-        return super(ConditionalPriorMixin, self).prob(val)
-
-    def ln_prob(self, val, **conditional_variables):
-        self.update_conditions(**conditional_variables)
-        return super(ConditionalPriorMixin, self).ln_prob(val)
-
-    def update_conditions(self, **conditional_variables):
-        conditioned_params = self.condition_func(self.reference_params, **conditional_variables)
-        for key, value in conditioned_params.items():
-            setattr(self, key, value)
-
-    def setup_conditional_prior(self, condition_func, **reference_params):
-        self.condition_func = condition_func
-        self._reference_params = reference_params
-
-    def get_conditionable_variables(self, all_vars):
-        vars = all_vars.copy()
-        for param in ['self', 'name', 'latex_label', 'unit', 'boundary', 'condition_func']:
-            try:
-                del vars[param]
-            except KeyError:
-                continue
-        vars = {key: all_vars[key] for key in vars}
-        return vars
-
-    @property
-    def reference_params(self):
-        return self._reference_params
-
-    @property
-    def condition_func(self):
-        return self._condition_func
-
-    @condition_func.setter
-    def condition_func(self, condition_func):
-        if not condition_func:
-            self._condition_func = lambda reference_param_dict, conditional_variables: reference_param_dict
-        else:
-            self._condition_func = condition_func
-
-    @property
-    def conditional_variables(self):
-        return infer_parameters_from_function(self.condition_func)
-
-
 def conditional_prior_factory(prior_class):
-    class ConditionalPrior(ConditionalPriorMixin, prior_class):
-        def __init__(self, **params):
-            condition_func = params['condition_func']
-            del params['condition_func']
-            super(ConditionalPrior, self).__init__(**params)
-            self.setup_conditional_prior(condition_func=condition_func,
-                                         **self.get_conditionable_variables(params))
+    class ConditionalPrior(prior_class):
+        def __init__(self, name=None, latex_label=None, unit=None,
+                     condition_func=None, boundary=None, **conditional_params):
+            super(ConditionalPrior, self).__init__(name=name, latex_label=latex_label,
+                                                   unit=unit, boundary=boundary, **conditional_params)
+            self.condition_func = condition_func
+            self._reference_params = conditional_params
+
+        def sample(self, size=None, **independent_variables):
+            """Draw a sample from the prior
+
+            Parameters
+            ----------
+            size: int or tuple of ints, optional
+                See numpy.random.uniform docs
+
+            Returns
+            -------
+            float: A random number between 0 and 1, rescaled to match the distribution of this Prior
+
+            """
+            self.least_recently_sampled = self.rescale(np.random.uniform(0, 1, size), **independent_variables)
+            return self.least_recently_sampled
+
+        def rescale(self, val, **independent_variables):
+            """
+            'Rescale' a sample from the unit line element to the appropriate Gaussian prior.
+
+            This maps to the inverse CDF. This has been analytically solved for this case.
+            """
+            self.update_conditions(**independent_variables)
+            return super(ConditionalPrior, self).rescale(val)
+
+        def prob(self, val, **independent_variables):
+            """Return the prior probability of val.
+
+            Parameters
+            ----------
+            val: Union[float, int, array_like]
+
+            Returns
+            -------
+            float: Prior probability of val
+            """
+            self.update_conditions(**independent_variables)
+            return super(ConditionalPrior, self).prob(val)
+
+        def ln_prob(self, val, **independent_variables):
+            self.update_conditions(**independent_variables)
+            return super(ConditionalPrior, self).ln_prob(val)
+
+        def update_conditions(self, **independent_variables):
+            conditioned_parameters = self.condition_func(self.reference_params, **independent_variables)
+            for key, value in conditioned_parameters.items():
+                setattr(self, key, value)
+
+        @property
+        def reference_params(self):
+            """
+
+            Returns
+            -------
+
+            """
+            return self._reference_params
+
+        @property
+        def condition_func(self):
+            """
+
+            Returns
+            -------
+
+            """
+            return self._condition_func
+
+        @condition_func.setter
+        def condition_func(self, condition_func):
+            if not condition_func:
+                self._condition_func = lambda reference_params, **kwargs: reference_params
+            else:
+                self._condition_func = condition_func
+
+        @property
+        def independent_variables(self):
+            return infer_parameters_from_function(self.condition_func)
 
     return ConditionalPrior
 
diff --git a/examples/other_examples/conditional_prior.py b/examples/other_examples/conditional_prior.py
new file mode 100644
index 000000000..7b291bea7
--- /dev/null
+++ b/examples/other_examples/conditional_prior.py
@@ -0,0 +1,45 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import corner
+
+import bilby.gw.prior
+
+
+mass_1_min = 2
+mass_1_max = 50
+
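+# The condition function receives the prior's reference parameters plus the
+# independent variables it is conditioned on (here mass_1) and returns the
+# updated parameter values; this toy version passes mu through unchanged.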
+def condition_function(reference_params, mass_1):
+    return dict(mu=reference_params['mu'])
+
+mass_1 = bilby.core.prior.PowerLaw(alpha=-2.5, minimum=mass_1_min, maximum=mass_1_max, name='mass_1')
+mass_ratio = bilby.core.prior.ConditionalExponential(mu=2, name='mass_ratio',
+                                                     condition_func=condition_function)
+
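+# ConditionalPriorDict resolves the sampling order automatically, so mass_1
+# is drawn before the conditional mass_ratio prior.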
+conditional_dict = bilby.core.prior.ConditionalPriorDict(dictionary=dict(mass_1=mass_1, mass_ratio=mass_ratio))
+
+res = conditional_dict.sample(100000)
+
+plt.hist(res['mass_1'], bins='fd', alpha=0.6, density=True, label='Sampled')
+plt.plot(np.linspace(2, 50, 200), conditional_dict['mass_1'].prob(np.linspace(2, 50, 200)), label='Powerlaw prior')
+plt.xlabel('$m_1$')
+plt.ylabel('$p(m_1)$')
+plt.loglog()
+plt.legend()
+plt.tight_layout()
+plt.show()
+plt.clf()
+
+
+plt.hist(res['mass_ratio'], bins='fd', alpha=0.6, density=True, label='Sampled')
+plt.xlabel('$q$')
+plt.ylabel('$p(q | m_1)$')
+plt.loglog()
+plt.legend()
+plt.tight_layout()
+plt.show()
+plt.clf()
diff --git a/examples/other_examples/correlated_prior.py b/examples/other_examples/correlated_prior.py
deleted file mode 100644
index 952e3d46d..000000000
--- a/examples/other_examples/correlated_prior.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import matplotlib.pyplot as plt
-import numpy as np
-import corner
-
-import bilby.gw.prior
-
-# mass_1 = bilby.core.prior.Uniform(5, 100)
-# mass_2 = bilby.gw.prior.CorrelatedSecondaryMassPrior(minimum=5, maximum=100)
-#
-# correlated_priors = bilby.core.prior.CorrelatedPriorDict(dictionary=dict(mass_1=mass_1, mass_2=mass_2))
-#
-# samples = correlated_priors.sample(10)
-#
-# primary_masses = samples['mass_1']
-# secondary_masses = samples['mass_2']
-# for i in range(len(primary_masses)):
-#     if primary_masses[i] > secondary_masses[i]:
-#         print('True')
-#     else:
-#         print('False')
-#
-# sample = dict(mass_1=25, mass_2=20)
-# print(correlated_priors.prob(sample))
-
-
-def correlation_func_a(mu, a=0):
-    return mu + a**2 + 2 * a + 3
-
-
-def correlation_func_b(mu, a=0, b=0):
-    return mu + 0.01 * a**2 + 0.01 * b**2 + 0.01 * a * b + 0.1 * b + 3
-
-
-a = bilby.core.prior.Gaussian(mu=0., sigma=1)
-b = bilby.core.prior.CorrelatedGaussian(mu=0., sigma=1, correlation_func=correlation_func_a)
-c = bilby.core.prior.CorrelatedGaussian(mu=0, sigma=1, correlation_func=correlation_func_b)
-
-correlated_uniform = bilby.core.prior.CorrelatedPriorDict(dictionary=dict(a=a, b=b, c=c))
-
-samples = correlated_uniform.sample(1000000)
-
-samples = np.array([samples['a'], samples['b'], samples['c']]).T
-corner.corner(np.array(samples))
-plt.show()
-
-
-def correlation_func_min_max(extrema_dict, a, b):
-    maximum = extrema_dict['maximum'] + a**b
-    minimum = np.log(b)
-    return minimum, maximum
-
-
-a = bilby.core.prior.Uniform(minimum=0, maximum=1)
-b = bilby.core.prior.Uniform(minimum=1e-6, maximum=1e-1)
-c = bilby.core.prior.CorrelatedUniform(minimum=0, maximum=1, correlation_func=correlation_func_min_max)
-
-correlated_uniform = bilby.core.prior.CorrelatedPriorDict(dictionary=dict(a=a, b=b, c=c))
-
-samples = correlated_uniform.sample(1000000)
-samples = np.array([samples['a'], samples['b'], samples['c']]).T
-corner.corner(np.array(samples))
-plt.show()
-- 
GitLab