Skip to content
Snippets Groups Projects
Commit 5c494930 authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Merge branch 'fix_step_kwargs' into 'master'

pymc3.py: explicitly set step method arguments

See merge request !293
parents 712192db 52cdf6fc
No related branches found
No related tags found
1 merge request!293pymc3.py: explicitly set step method arguments
Pipeline #39980 passed
......@@ -425,6 +425,10 @@ class Pymc3(MCMCSampler):
# set the prior
self.set_prior()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
# set the step method
if isinstance(self.step_method, (dict, OrderedDict)):
# create list of step methods (any not given will default to NUTS)
......@@ -435,19 +439,59 @@ class Pymc3(MCMCSampler):
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
curmethod = self.step_method[key].lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]](vars=[self.pymc3_priors[key]], **args))
else:
with self.pymc3_model:
# check for a compound step list
if isinstance(self.step_method, list):
compound = []
for sms in self.step_method:
compound.append(pymc3.__dict__[step_methods[sms.lower()]]())
curmethod = sms.lower()
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
compound.append(pymc3.__dict__[step_methods[curmethod]](**args))
else:
self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]()
self.kwargs['step'] = None
if self.step_method is not None:
curmethod = self.step_method.lower()
args = {}
if curmethod == 'nuts':
if nuts_kwargs is not None:
args = nuts_kwargs
else:
args = step_kwargs.get('nuts', {})
else:
args = step_kwargs.get(curmethod, {})
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs and nuts_kwargs if no step methods are set
self.kwargs['nuts_kwargs'] = nuts_kwargs
self.kwargs['step_kwargs'] = step_kwargs
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment