Commit 1be08c59 authored by Matthew David Pitkin's avatar Matthew David Pitkin Committed by Rhys Green

Fixes to PyMC3 autoinitialisation of NUTS step method

parent a4467c22
......@@ -424,9 +424,21 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
methodslist = []
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
......@@ -438,12 +450,15 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -454,12 +469,15 @@ class Pymc3(MCMCSampler):
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -475,47 +493,48 @@ class Pymc3(MCMCSampler):
compound = []
for sms in self.step_method:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
self.kwargs['step'] = compound
else:
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
with self.pymc3_model:
# perform the sampling
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment