Skip to content
Snippets Groups Projects
Commit 53d27a61 authored by Matthew Pitkin's avatar Matthew Pitkin
Browse files

Allow PyMC3 step name to be passed to code

parent 92e4b033
No related branches found
No related tags found
1 merge request!139Add PyMC3 sampler
......@@ -944,6 +944,34 @@ class Ptemcee(Emcee):
class Pymc3(Sampler):
""" https://docs.pymc.io/ """
@property
def kwargs(self):
""" Ensures that proper keyword arguments are used for the Pymc3 sampler.
Returns
-------
dict: Keyword arguments used for the Nestle Sampler
"""
return self.__kwargs
@kwargs.setter
def kwargs(self, kwargs):
self.__kwargs = dict()
self.__kwargs.update(kwargs)
# set some defaults
# set the number of draws
self.draws = 1000 if 'draws' not in self.__kwargs else self.__kwargs.pop('draws')
if 'chains' not in self.__kwargs:
self.__kwargs['chains'] = 2
self.chains = self.__kwargs['chains']
if 'cores' not in self.__kwargs:
self.__kwargs['cores'] = 1
def setup_prior_mapping(self):
"""
Set the mapping between predefined tupak priors and the equivalent
......@@ -1108,44 +1136,57 @@ class Pymc3(Sampler):
except ImportError:
raise ImportError("You must have Theano installed to use PyMC3")
class Pymc3PowerLaw(pymc3.Continuous):
def __init__(self, lower, upper, alpha, testval=1):
falpha = alpha
self.lower = lower = tt.as_tensor_variable(floatX(lower))
self.upper = upper = tt.as_tensor_variable(floatX(upper))
self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
if self.priors[key].alpha < -1.:
# use Pareto distribution
palpha = -(1. + self.priors[key].alpha)
if falpha == -1:
self.norm = 1./(tt.log(self.upper/self.lower))
else:
self.norm = (1. + self.alpha) / (tt.pow(self.upper, (1. + self.alpha))
- tt.pow(self.lower, (1. + self.alpha)))
transform = pymc3.distributions.transforms.interval(lower, upper)
super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval)
def logp(self, value):
upper = self.upper
lower = self.lower
alpha = self.alpha
return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper)
return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha)
return pymc3.Bound(pymc3.Pareto, upper=self.priors[key].minimum)(key, alpha=palpha, m=self.priors[key].maximum)
else:
class Pymc3PowerLaw(pymc3.Continuous):
def __init__(self, lower, upper, alpha, testval=1):
print(lower, upper, alpha)
falpha = alpha
self.lower = lower = tt.as_tensor_variable(floatX(lower))
self.upper = upper = tt.as_tensor_variable(floatX(upper))
self.alpha = alpha = tt.as_tensor_variable(floatX(alpha))
if falpha == -1:
self.norm = 1./(tt.log(self.upper/self.lower))
else:
beta = (1. + self.alpha)
self.norm = 1. /(beta * (tt.pow(self.upper, beta)
- tt.pow(self.lower, beta)))
transform = pymc3.distributions.transforms.interval(lower, upper)
super(Pymc3PowerLaw, self).__init__(transform=transform, testval=testval)
def logp(self, value):
upper = self.upper
lower = self.lower
alpha = self.alpha
return pymc3.distributions.dist_math.bound(self.alpha*tt.log(value) + tt.log(self.norm), lower <= value, value <= upper)
return Pymc3PowerLaw(key, lower=self.priors[key].minimum, upper=self.priors[key].maximum, alpha=self.priors[key].alpha)
else:
raise ValueError("Prior for '{}' is not a Power Law".format(key))
def _run_external_sampler(self):
pymc3 = self.external_sampler
# set kwargs
self.draws = self.kwargs.get('draws', 1000)
self.chains = self.kwargs.get('chains', 2)
self.cores = self.kwargs.get('cores', 1)
self.tune = self.kwargs.get('tune', 1000) # burn in samples
self.discard_tuned_samples = self.kwargs.get('discard_tuned_samples',
True)
# set the step method
from pymc3.sampling import STEP_METHODS
step_methods = {m.__name__.lower(): m.__name__ for m in STEP_METHODS}
if 'step' in self.__kwargs:
step_method = self.__kwargs.pop('step').lower()
if step_method not in step_methods:
raise ValueError("Using invalid step method '{}'".format(step_method))
else:
step_method = None
# initialise the PyMC3 model
model = pymc3.Model()
......@@ -1157,10 +1198,10 @@ class Pymc3(Sampler):
# set the likelihood function
self.set_likelihood()
sm = None if step_method is None else pymc3.__dict__[step_methods[step_method]]()
# perform the sampling
trace = pymc3.sample(self.draws, tune=self.tune, cores=self.cores,
chains=self.chains,
discard_tuned_samples=self.discard_tuned_samples)
trace = pymc3.sample(self.draws, step=sm, **self.kwargs)
nparams = len(self.priors.keys())
nsamples = len(trace)*self.chains
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment