From 53d27a61f8d7f1a59fe1cbd5e4f1af44d589d33d Mon Sep 17 00:00:00 2001
From: Matthew Pitkin <matthew.pitkin@ligo.org>
Date: Tue, 7 Aug 2018 17:09:12 +0100
Subject: [PATCH] Allow PyMC3 step name to be passed to code

---
 tupak/core/sampler.py | 109 +++++++++++++++++++++++++++++-------------
 1 file changed, 75 insertions(+), 34 deletions(-)

diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py
index 57ae0a393..a0777af7f 100644
--- a/tupak/core/sampler.py
+++ b/tupak/core/sampler.py
@@ -944,6 +944,34 @@ class Ptemcee(Emcee):
 class Pymc3(Sampler):
     """ https://docs.pymc.io/ """
 
+    @property
+    def kwargs(self):
+        """ Ensures that proper keyword arguments are used for the Pymc3 sampler.
+
+        Returns
+        -------
+        dict: Keyword arguments used for the Nestle Sampler
+
+        """
+        return self.__kwargs
+
+    @kwargs.setter
+    def kwargs(self, kwargs):
+        self.__kwargs = dict()
+        self.__kwargs.update(kwargs)
+
+        # set some defaults
+
+        # set the number of draws
+        self.draws = 1000 if 'draws' not in self.__kwargs else self.__kwargs.pop('draws')
+
+        if 'chains' not in self.__kwargs:
+            self.__kwargs['chains'] = 2
+            self.chains = self.__kwargs['chains']
+
+        if 'cores' not in self.__kwargs:
+            self.__kwargs['cores'] = 1
+
     def setup_prior_mapping(self):
         """
         Set the mapping between predefined tupak priors and the equivalent
@@ -1108,44 +1136,57 @@ class Pymc3(Sampler):
             except ImportError:
                 raise ImportError("You must have Theano installed to use PyMC3")
 
-            class Pymc3PowerLaw(pymc3.Continuous):
-                def __init__(self, lower, upper, alpha, testval=1):
-                    falpha = alpha
-                    self.lower = lower = tt.as_tensor_variable(floatX(lower))
-                    self.upper = upper = tt.as_tensor_variable(floatX(upper))
-                    self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
+            if self.priors[key].alpha < -1.:
+                # use Pareto distribution
+                palpha = -(1. + self.priors[key].alpha)
 
-                    if falpha == -1:
-                        self.norm = 1./(tt.log(self.upper/self.lower))
-                    else:
-                        self.norm = (1. + self.alpha) / (tt.pow(self.upper, (1. + self.alpha)) 
-                                          - tt.pow(self.lower, (1. + self.alpha)))
-
-                    transform = pymc3.distributions.transforms.interval(lower, upper)
-
-                    super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval)
-
-                def logp(self, value):
-                    upper = self.upper
-                    lower = self.lower
-                    alpha = self.alpha
-
-                    return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper)
-
-            return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha)
+                return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)(key, alpha=palpha, m=self.priors[key].maximum)
+            else:
+                class Pymc3PowerLaw(pymc3.Continuous):
+                    def __init__(self, lower, upper, alpha, testval=1):
+                        print(lower, upper, alpha)
+                        
+                        falpha = alpha
+                        self.lower = lower = tt.as_tensor_variable(floatX(lower))
+                        self.upper = upper = tt.as_tensor_variable(floatX(upper))
+                        self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
+
+                        if falpha == -1:
+                            self.norm = 1./(tt.log(self.upper/self.lower))
+                        else:
+                            beta = (1. + self.alpha)
+                            self.norm = 1. /(beta * (tt.pow(self.upper, beta) 
+                                          - tt.pow(self.lower, beta)))
+
+                        transform = pymc3.distributions.transforms.interval(lower, upper)
+
+                        super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval)
+
+                    def logp(self, value):
+                        upper = self.upper
+                        lower = self.lower
+                        alpha = self.alpha
+
+                        return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper)
+
+                return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha)
         else:
             raise ValueError("Prior for '{}' is not a Power Law".format(key))
 
     def _run_external_sampler(self):
         pymc3 = self.external_sampler
 
-        # set kwargs
-        self.draws = self.kwargs.get('draws', 1000)
-        self.chains = self.kwargs.get('chains', 2)
-        self.cores = self.kwargs.get('cores', 1)
-        self.tune = self.kwargs.get('tune', 1000) # burn in samples
-        self.discard_tuned_samples = self.kwargs.get('discard_tuned_samples',
-                                                     True)
+        # set the step method
+        from pymc3.sampling import STEP_METHODS
+
+        step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
+        if 'step' in self.__kwargs:
+            step_method = self.__kwargs.pop('step').lower()
+
+            if step_method not in step_methods:
+                raise ValueError("Using invalid step method '{}'".format(step_method))
+        else:
+            step_method = None
 
         # initialise the PyMC3 model
         model = pymc3.Model()
@@ -1157,10 +1198,10 @@ class Pymc3(Sampler):
             # set the likelihood function
             self.set_likelihood()
 
+            sm = None if step_method is None else pymc3.__dict__[step_methods[step_method]]()
+
             # perform the sampling
-            trace = pymc3.sample(self.draws, tune=self.tune, cores=self.cores,
-                chains=self.chains, 
-                discard_tuned_samples=self.discard_tuned_samples)
+            trace = pymc3.sample(self.draws, step=sm, **self.kwargs)
 
         nparams = len(self.priors.keys())
         nsamples = len(trace)*self.chains
-- 
GitLab