Maintenance will be performed on git.ligo.org, chat.ligo.org, and docs.ligo.org, starting at approximately 10am CDT Tuesday 20 August 2019. The maintenance is expected to take around an hour and here will be two short periods of downtime, one at the beginning of the maintenance and another at the end.

Commit 1be08c59 authored by Matthew David Pitkin's avatar Matthew David Pitkin Committed by Rhys Green

Fixes to PyMC3 autoinitialisation of NUTS step method

parent a4467c22
......@@ -424,9 +424,21 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
methodslist = []
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
......@@ -438,12 +450,15 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -454,12 +469,15 @@ class Pymc3(MCMCSampler):
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
......@@ -475,47 +493,48 @@ class Pymc3(MCMCSampler):
compound = []
for sms in self.step_method:
curmethod = sms.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
self.kwargs['step'] = compound
else:
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
methodslist.append(curmethod)
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
elif step_kwargs is not None:
args = step_kwargs.get('nuts', {})
args = step_kwargs.pop('nuts', {})
# add values into nuts_kwargs
nuts_kwargs = args
else:
args = {}
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
# takes in a Pymc3 Sampler, with a pymc3_model attribute, and defines
# the likelihood within that context manager
likeargs = infer_args_from_method(self.likelihood.log_likelihood)
if 'sampler' in likeargs:
self.likelihood.log_likelihood(sampler=self)
else:
# set the likelihood function from predefined functions
self.set_likelihood()
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
with self.pymc3_model:
# perform the sampling
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment