From 53d27a61f8d7f1a59fe1cbd5e4f1af44d589d33d Mon Sep 17 00:00:00 2001 From: Matthew Pitkin <matthew.pitkin@ligo.org> Date: Tue, 7 Aug 2018 17:09:12 +0100 Subject: [PATCH] Allow PyMC3 step name to be passed to code --- tupak/core/sampler.py | 109 +++++++++++++++++++++++++++++------------- 1 file changed, 75 insertions(+), 34 deletions(-) diff --git a/tupak/core/sampler.py b/tupak/core/sampler.py index 57ae0a393..a0777af7f 100644 --- a/tupak/core/sampler.py +++ b/tupak/core/sampler.py @@ -944,6 +944,34 @@ class Ptemcee(Emcee): class Pymc3(Sampler): """ https://docs.pymc.io/ """ + @property + def kwargs(self): + """ Ensures that proper keyword arguments are used for the Pymc3 sampler. + + Returns + ------- + dict: Keyword arguments used for the Nestle Sampler + + """ + return self.__kwargs + + @kwargs.setter + def kwargs(self, kwargs): + self.__kwargs = dict() + self.__kwargs.update(kwargs) + + # set some defaults + + # set the number of draws + self.draws = 1000 if 'draws' not in self.__kwargs else self.__kwargs.pop('draws') + + if 'chains' not in self.__kwargs: + self.__kwargs['chains'] = 2 + self.chains = self.__kwargs['chains'] + + if 'cores' not in self.__kwargs: + self.__kwargs['cores'] = 1 + def setup_prior_mapping(self): """ Set the mapping between predefined tupak priors and the equivalent @@ -1108,44 +1136,57 @@ class Pymc3(Sampler): except ImportError: raise ImportError("You must have Theano installed to use PyMC3") - class Pymc3PowerLaw(pymc3.Continuous): - def __init__(self, lower, upper, alpha, testval=1): - falpha = alpha - self.lower = lower = tt.as_tensor_variable(floatX(lower)) - self.upper = upper = tt.as_tensor_variable(floatX(upper)) - self.alpha = alpha = tt.as_tensor_variable(floatX(alpha)) + if self.priors[key].alpha < -1.: + # use Pareto distribution + palpha = -(1. + self.priors[key].alpha) - if falpha == -1: - self.norm = 1./(tt.log(self.upper/self.lower)) - else: - self.norm = (1. + self.alpha) / (tt.pow(self.upper, (1. + self.alpha)) - - tt.pow(self.lower, (1. + self.alpha))) - - transform = pymc3.distributions.transforms.interval(lower, upper) - - super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval) - - def logp(self, value): - upper = self.upper - lower = self.lower - alpha = self.alpha - - return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper) - - return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha) + return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)(key, alpha=palpha, m=self.priors[key].maximum) + else: + class Pymc3PowerLaw(pymc3.Continuous): + def __init__(self, lower, upper, alpha, testval=1): + print(lower, upper, alpha) + + falpha = alpha + self.lower = lower = tt.as_tensor_variable(floatX(lower)) + self.upper = upper = tt.as_tensor_variable(floatX(upper)) + self.alpha = alpha = tt.as_tensor_variable(floatX(alpha)) + + if falpha == -1: + self.norm = 1./(tt.log(self.upper/self.lower)) + else: + beta = (1. + self.alpha) + self.norm = 1. /(beta * (tt.pow(self.upper, beta) + - tt.pow(self.lower, beta))) + + transform = pymc3.distributions.transforms.interval(lower, upper) + + super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval) + + def logp(self, value): + upper = self.upper + lower = self.lower + alpha = self.alpha + + return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper) + + return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha) else: raise ValueError("Prior for '{}' is not a Power Law".format(key)) def _run_external_sampler(self): pymc3 = self.external_sampler - # set kwargs - self.draws = self.kwargs.get('draws', 1000) - self.chains = self.kwargs.get('chains', 2) - self.cores = self.kwargs.get('cores', 1) - self.tune = self.kwargs.get('tune', 1000) # burn in samples - self.discard_tuned_samples = self.kwargs.get('discard_tuned_samples', - True) + # set the step method + from pymc3.sampling import STEP_METHODS + + step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS} + if 'step' in self.__kwargs: + step_method = self.__kwargs.pop('step').lower() + + if step_method not in step_methods: + raise ValueError("Using invalid step method '{}'".format(step_method)) + else: + step_method = None # initialise the PyMC3 model model = pymc3.Model() @@ -1157,10 +1198,10 @@ class Pymc3(Sampler): # set the likelihood function self.set_likelihood() + sm = None if step_method is None else pymc3.__dict__[step_methods[step_method]]() + # perform the sampling - trace = pymc3.sample(self.draws, tune=self.tune, cores=self.cores, - chains=self.chains, - discard_tuned_samples=self.discard_tuned_samples) + trace = pymc3.sample(self.draws, step=sm, **self.kwargs) nparams = len(self.priors.keys()) nsamples = len(trace)*self.chains -- GitLab