Skip to content
Snippets Groups Projects
Commit 52cdf6fc authored by Matthew David Pitkin's avatar Matthew David Pitkin
Browse files

pymc3.py: explicitly set step method arguments

 - PyMC3 only applies the step_kwargs and nuts_kwargs arguments
   to the step methods if they are automaticallu allocated. Otherwise,
   it assumes you have already initialised the method with the
   required keywork arguments. This was not being done in bilby, so
   this patch fixes it and does explicitly initialise the step methods
   with values passed to step_kwargs or nuts_kwargs as required.
parent 712192db
No related branches found
No related tags found
No related merge requests found
......@@ -425,6 +425,10 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
# create list of step methods (any not given will default to NUTS)
......@@ -435,19 +439,59 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
with self.pymc3_model:
# check for a compound step list
if isinstance(self.step_method, list):
compound = []
for sms in self.step_method:
compound.append(pymc3.__dict__[step_methods[sms.lower()]]())
curmethod = sms.lower()
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
else:
self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]()
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment