From 5c930ea86d2464bf70811cae05dbea0ed25ae031 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Fri, 24 Apr 2020 17:20:41 -0500
Subject: [PATCH] Dirichlet priors

---
 bilby/core/likelihood.py            | 49 ++++++++++++++++++++++++-
 bilby/core/prior/conditional.py     | 56 +++++++++++++++++++++++++++++
 bilby/core/prior/dict.py            | 43 ++++++++++++++++++++++
 examples/core_examples/dirichlet.py | 33 +++++++++++++++++
 4 files changed, 180 insertions(+), 1 deletion(-)
 create mode 100644 examples/core_examples/dirichlet.py

diff --git a/bilby/core/likelihood.py b/bilby/core/likelihood.py
index 11fbd34cd..3441a1a6d 100644
--- a/bilby/core/likelihood.py
+++ b/bilby/core/likelihood.py
@@ -2,7 +2,7 @@ from __future__ import division, print_function
 import copy
 
 import numpy as np
-from scipy.special import gammaln
+from scipy.special import gammaln, xlogy
 from scipy.stats import multivariate_normal
 
 from .utils import infer_parameters_from_function
@@ -402,6 +402,53 @@ class StudentTLikelihood(Analytical1DLikelihood):
         self._nu = nu
 
 
+class Multinomial(Likelihood):
+    """
+    Likelihood for a system with N discrete possibilities.
+    """
+
+    def __init__(self, data, n_dimensions, label="parameter_"):
+        """
+
+        Parameters
+        ----------
+        data: array-like
+            The number of objects in each class
+        n_dimensions: int
+            The number of classes
+        label: str
+            Base name for the sampled parameters; the parameters are
+            label + str(i) for i in range(n_dimensions - 1)
+        """
+        self.data = np.array(data)
+        self._total = np.sum(self.data)
+        super(Multinomial, self).__init__(dict())
+        self.n = n_dimensions
+        self.label = label
+        self._nll = None
+
+    def log_likelihood(self):
+        """
+        Since n - 1 parameters are sampled, the last parameter is 1 - the rest
+        """
+        probs = [self.parameters[self.label + str(ii)]
+                 for ii in range(self.n - 1)]
+        probs.append(1 - sum(probs))
+        return self._multinomial_ln_pdf(probs=probs)
+
+    def noise_log_likelihood(self):
+        """
+        Our null hypothesis is that all bins have probability 1 / nbins, i.e.,
+        no bin is preferred over any other.
+        """
+        if self._nll is None:
+            self._nll = self._multinomial_ln_pdf(probs=1 / self.n)
+        return self._nll
+
+    def _multinomial_ln_pdf(self, probs):
+        """Lifted from scipy.stats.multinomial._logpdf"""
+        ln_prob = gammaln(self._total + 1) + np.sum(
+            xlogy(self.data, probs) - gammaln(self.data + 1), axis=-1)
+        return ln_prob
+
+
 class AnalyticalMultidimensionalCovariantGaussian(Likelihood):
     """
         A multivariate Gaussian likelihood
diff --git a/bilby/core/prior/conditional.py b/bilby/core/prior/conditional.py
index 2e4a25e1b..cffbf2cec 100644
--- a/bilby/core/prior/conditional.py
+++ b/bilby/core/prior/conditional.py
@@ -225,6 +225,62 @@ ConditionalFermiDirac = conditional_prior_factory(FermiDirac)
 ConditionalInterped = conditional_prior_factory(Interped)
 
 
+class DirichletElement(ConditionalBeta):
+    """
+    Single element in a dirichlet distribution
+
+    The probability scales as
+
+    $p(x_order) \propto (x_max - x_order)^(n_dimensions - order - 2)$
+
+    for x_order < x_max, where x_max is the sum of x_i for i < order
+
+    Examples
+    --------
+    n_dimensions = 1:
+    p(x_0) \propto 1 ; 0 < x_0 < 1
+    n_dimensions = 2:
+    p(x_0) \propto (1 - x_0) ; 0 < x_0 < 1
+    p(x_1) \propto 1 ; 0 < x_1 < 1
+
+    Parameters
+    ----------
+    order: int
+        Order of this element of the Dirichlet distribution.
+    n_dimensions: int
+        Total number of elements of the Dirichlet distribution.
+    label: str
+        Label for the Dirichlet distribution.
+        This should be the same for all elements.
+    """
+
+    def __init__(self, order, n_dimensions, label):
+        super(DirichletElement, self).__init__(
+            minimum=0, maximum=1, alpha=1, beta=n_dimensions - order - 1,
+            name=label + str(order),
+            condition_func=self.dirichlet_condition
+        )
+        self.label = label
+        self.n_dimensions = n_dimensions
+        self.order = order
+        self._required_variables = [
+            label + str(ii) for ii in range(order)
+        ]
+        self.__class__.__name__ = 'Dirichlet'
+
+    def dirichlet_condition(self, reference_params, **kwargs):
+        # The upper bound of this element is the probability left over after
+        # the lower-order elements have been drawn.
+        remaining = 1 - sum(
+            [kwargs[self.label + str(ii)] for ii in range(self.order)]
+        )
+        return dict(minimum=reference_params["minimum"], maximum=remaining)
+
+    def __repr__(self):
+        return Prior.__repr__(self)
+
+    def get_instantiation_dict(self):
+        return Prior.get_instantiation_dict(self)
+
+
 class ConditionalPriorException(PriorException):
     """ General base class for all conditional prior exceptions """
 
diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index a375625b3..a6be2ba73 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -725,6 +725,49 @@ class ConditionalPriorDict(PriorDict):
         self._resolve_conditions()
 
 
+class DirichletPriorDict(ConditionalPriorDict):
+
+    def __init__(self, n_dim=None, label="dirichlet_"):
+        from .conditional import DirichletElement
+        self.n_dim = n_dim
+        self.label = label
+        super(DirichletPriorDict, self).__init__(dictionary=dict())
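+        # Only the first n_dim - 1 elements are sampled; the final component
+        # of the Dirichlet-distributed vector is fixed by normalisation
+        # (the components must sum to one) and so has no explicit prior.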
+        for ii in range(n_dim - 1):
+            self[label + "{}".format(ii)] = DirichletElement(
+                order=ii, n_dimensions=n_dim, label=label
+            )
+
+    def copy(self, **kwargs):
+        return self.__class__(n_dim=self.n_dim, label=self.label)
+
+    def _get_json_dict(self):
+        total_dict = dict()
+        total_dict["__prior_dict__"] = True
+        total_dict["__module__"] = self.__module__
+        total_dict["__name__"] = self.__class__.__name__
+        total_dict["n_dim"] = self.n_dim
+        total_dict["label"] = self.label
+        return total_dict
+
+    @classmethod
+    def _get_from_json_dict(cls, prior_dict):
+        try:
+            # Recover the class that was serialised so that subclasses are
+            # reconstructed correctly.
+            cls = getattr(
+                import_module(prior_dict["__module__"]),
+                prior_dict["__name__"])
+        except ImportError:
+            logger.debug("Cannot import prior module {}.{}".format(
+                prior_dict["__module__"], prior_dict["__name__"]
+            ))
+        except KeyError:
+            logger.debug("Cannot find module name to load")
+        for key in ["__module__", "__name__", "__prior_dict__"]:
+            if key in prior_dict:
+                del prior_dict[key]
+        obj = cls(**prior_dict)
+        return obj
+
+
 class ConditionalPriorDictException(PriorDictException):
     """ General base class for all conditional prior dict exceptions """
 
diff --git a/examples/core_examples/dirichlet.py b/examples/core_examples/dirichlet.py
new file mode 100644
index 000000000..494e37c21
--- /dev/null
+++ b/examples/core_examples/dirichlet.py
@@ -0,0 +1,33 @@
+import numpy as np
+import pandas as pd
+
+from bilby.core.likelihood import Multinomial
+from bilby.core.prior import DirichletPriorDict
+from bilby.core.sampler import run_sampler
+
+
+n_dim = 3
+label = "dirichlet_"
+priors = DirichletPriorDict(n_dim=n_dim, label=label)
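+
+# Illustrative check (optional): a single draw from the prior shows that only
+# the first n_dim - 1 components are sampled, with each conditioned on the
+# components of lower order.
+print(priors.sample())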
+
+injection_parameters = dict(
+    dirichlet_0=1 / 3,
+    dirichlet_1=1 / 3,
+    dirichlet_2=1 / 3,
+)
+data = [injection_parameters[label + str(ii)] * 1000 for ii in range(n_dim)]
+
+likelihood = Multinomial(data=data, n_dimensions=n_dim, label=label)
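+
+# Illustrative check (optional): evaluate the likelihood at the injected
+# values. Only the first n_dim - 1 probabilities are set as parameters; the
+# likelihood treats the final class probability as one minus their sum.
+likelihood.parameters.update(
+    {label + str(ii): injection_parameters[label + str(ii)]
+     for ii in range(n_dim - 1)}
+)
+print("log likelihood at injection: {}".format(likelihood.log_likelihood()))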
+
+result = run_sampler(
+    likelihood=likelihood, priors=priors, nlive=100, walks=10,
+    label="multinomial", injection_parameters=injection_parameters
+)
+
+# Reconstruct the unsampled final component as one minus the sum of the rest.
+result.posterior[label + str(n_dim - 1)] = 1 - np.sum(
+    [result.posterior[key] for key in priors], axis=0)
+result.plot_corner(parameters=injection_parameters)
+
+# Compare against samples drawn directly from the prior.
+samples = priors.sample(10000)
+samples[label + str(n_dim - 1)] = 1 - np.sum(
+    [samples[key] for key in priors], axis=0)
+result.posterior = pd.DataFrame(samples)
+result.plot_corner(
+    parameters=[key for key in samples],
+    filename="outdir/dirichlet_prior_corner.png")
-- 
GitLab