From 765090051a82f294f99c8b21e2317cbe632b0524 Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Sat, 4 Aug 2018 23:00:29 +0100
Subject: [PATCH] Move setting of PyMC3 priors:  - move the setting of the
 PyMC3 priors into the Pymc3 sampler    class

---
 tupak/core/likelihood.py | 20 +++++++--------
 tupak/core/prior.py      | 55 ++++++++++++++++++++++++++--------------
 tupak/core/sampler.py    | 40 ++++++++++++++++++++++++++++-
 3 files changed, 85 insertions(+), 30 deletions(-)

diff --git a/tupak/core/likelihood.py b/tupak/core/likelihood.py
index c8d57fca3..575e2d2a7 100644
--- a/tupak/core/likelihood.py
+++ b/tupak/core/likelihood.py
@@ -124,20 +124,20 @@ class GaussianLikelihood(Likelihood):
         if samplername != 'pymc3':
             raise ValueError("Only use this class method for PyMC3 sampler")
 
-        model_parameters = dict()
-        for key in sampler.priors:
-            if key == 'sigma' and self.sigma is None:
-                self.sigma = sampler.priors[key].pymc3_prior(sampler)
+        if 'sigma' in sampler.pymc3_priors:
+            # if sigma is suppled use that value
+            if self.sigma is None:
+                self.sigma = sampler.pymc3_priors.pop('sigma')
             else:
-                if key in self.function_keys:
-                    model_parameters[key] = sampler.priors[key].pymc3_prior(sampler)
-                else:
-                    raise ValueError("Prior key '{}' is not a function key!".format(key))
+                del sampler.pymc3_priors['sigma']
 
-        model = self.function(self.x, **model_parameters)
+        for key in sampler.pymc3_priors:
+            if key not in self.function_keys:
+                raise ValueError("Prior key '{}' is not a function key!".format(key))
 
-        return sampler.external_sampler.Normal('likelihood', mu=model, sd=self.sigma, observed=self.y)
+        model = self.function(self.x, **sampler.pymc3_priors)
 
+        return sampler.external_sampler.Normal('likelihood', mu=model, sd=self.sigma, observed=self.y)
 
 class PoissonLikelihood(Likelihood):
     def __init__(self, x, counts, func):
diff --git a/tupak/core/prior.py b/tupak/core/prior.py
index 41dd444dc..d48617970 100644
--- a/tupak/core/prior.py
+++ b/tupak/core/prior.py
@@ -467,22 +467,29 @@ class Prior(object):
             label = self.name
         return label
 
-    def set_pymc3_prior(self, sampler, priortype, **kwargs):
+    def pymc3_prior(self, sampler):
+        """
+        'pymc3_prior' A user defined PyMC3 prior.
+
+        This should be overwritten by each subclass if needed.
+
+        Parameters
+        ----------
+        val: float
+            A random number between 0 and 1
+
+        Returns
+        -------
+        None
+
+        """
+
         from tupak.core.sampler import Sampler
 
         if not isinstance(sampler, Sampler):
             raise ValueError("'sampler' is not a Sampler class")
 
-        try:
-            samplername = sampler.external_sampler.__name__
-        except ValueError:
-            raise ValueError("Sampler's 'external_sampler' has not been initialised")
-
-        if samplername != 'pymc3':
-            raise ValueError("Only use this class method for PyMC3 sampler")
-        
-        if priortype in sampler.external_sampler.__dict__:
-            return sampler.external_sampler.__dict__[priortype](self.name, **kwargs)
+        return None
 
 
 class DeltaFunction(Prior):
@@ -644,6 +651,10 @@ class Uniform(Prior):
         """
         Prior.__init__(self, name, latex_label, minimum, maximum)
 
+        # set PyMC3 Uniform distribution attributes
+        self.lower = self.minimum
+        self.upper = self.maximum
+
     def rescale(self, val):
         Prior.test_valid_for_rescaling(val)
         return self.minimum + val * (self.maximum - self.minimum)
@@ -676,14 +687,6 @@ class Uniform(Prior):
         return scipy.stats.uniform.logpdf(val, loc=self.minimum,
                                           scale=self.maximum-self.minimum)
 
-    def pymc3_prior(self, sampler):
-        priortype = 'Uniform'
-        priorargs = {}
-        priorargs['lower'] = self.minimum
-        priorargs['upper'] = self.maximum
-
-        return self.set_pymc3_prior(sampler, priortype, **priorargs)
-
 
 class LogUniform(PowerLaw):
 
@@ -855,6 +858,20 @@ class Gaussian(Prior):
         return Prior._subclass_repr_helper(self, subclass_args=['mu', 'sigma'])
 
 
+class Normal(Gaussian):
+    def __init__(self, mu, sigma, name=None, latex_label=None):
+        """A copy of the Gaussian prior, but with "Normal" name to copy that
+        used for the distribution in PyMC3.
+
+        """
+
+        Gaussian.__init__(self, mu, sigma, name, latex_label)
+
+        # set argument names used in PyMC3 distribution
+        self.mu = mu
+        self.sd = sigma
+
+
 class TruncatedGaussian(Prior):
 
     def __init__(self, mu, sigma, minimum, maximum, name=None, latex_label=None):
diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index 041b2c14f..b3cee6884 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -958,6 +958,8 @@ class Pymc3(Sampler):
         model = pymc3.Model()
 
         with model:
+            self.set_prior()
+
             likelihood = self.likelihood.pymc3_likelihood(self)
 
             # perform the sampling
@@ -965,7 +967,7 @@ class Pymc3(Sampler):
                 chains=self.chains, 
                 discard_tuned_samples=self.discard_tuned_samples)
 
-        nparams = int(len(trace.varnames)/self.chains)
+        nparams = len(trace.varnames)
         nsamples = len(trace)*self.chains
 
         self.result.samples = np.zeros((nsamples, nparams))
@@ -978,6 +980,42 @@ class Pymc3(Sampler):
         self.result.log_evidence_err = np.nan
         return self.result
 
+    def set_prior(self):
+        """ Set the PyMC3 prior distributions.
+
+        """
+
+        self.pymc3_priors = dict()
+
+        # set the parameter prior distributions
+        for key in self.priors:
+            # if the prior contains a pymc3_prior method use that otherwise try
+            # and find the PyMC3 distribution
+            if self.priors[key].pymc3_prior(self) is not None:
+                self.pymc3_priors[key] = self.priors[key].pymc3_prior(self)
+            else:
+                # use Prior distribution name
+                distname = self.priors[key].__class__.__name__
+
+                # check whether name is a PyMC3 distribution
+                if distname in self.external_sampler.__dict__:
+                    # check the required arguments for the PyMC3 distribution
+                    reqargs = inspect.getargspec(self.external_sampler.__dict__[distname].__init__).args[1:]
+
+                    priorkwargs = dict()
+
+                    # check whether the Prior class has required attributes
+                    for arg in reqargs:
+                        if hasattr(self.priors[key], arg):
+                            priorkwargs[arg] = getattr(self.priors[key], arg)
+                        else:
+                            priorkwargs[arg] = None
+
+                    # set the prior
+                    self.pymc3_priors[key] = self.external_sampler.__dict__[distname](key, **priorkwargs)
+                else:
+                    raise ValueError("Prior '{}' is not a PyMC3 distribution.".format(distname))
+
     def calculate_autocorrelation(self, samples, c=3):
         """ Uses the `emcee.autocorr` module to estimate the autocorrelation
 
-- 
GitLab