Skip to content
Snippets Groups Projects
Commit 1be08c59 authored by Matthew Pitkin's avatar Matthew Pitkin Committed by Rhys Green
Browse files

Fixes to PyMC3 autoinitialisation of NUTS step method

parent a4467c22
No related branches found
No related tags found
1 merge request!299Fixing some argument bugs in pymc3.py.
......@@ -424,9 +424,21 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
methodslist = []
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
......@@ -438,12 +450,15 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -454,12 +469,15 @@ class Pymc3(MCMCSampler):
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -475,47 +493,48 @@ class Pymc3(MCMCSampler):
compound = []
for sms in self.step_method:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
self.kwargs['step'] = compound
else:
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
with self.pymc3_model:
# perform the sampling
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment