Skip to content
Snippets Groups Projects
Commit 712192db authored by Colm Talbot's avatar Colm Talbot
Browse files

Merge branch 'pymc3_compoundstep' into 'master'

Attempt to allow compound step methods for PyMC3

See merge request !289
parents 1ce4e073 2de2a349
No related branches found
No related tags found
No related merge requests found
......@@ -36,11 +36,13 @@ class Pymc3(MCMCSampler):
chains.
step: str, dict
Provide a step method name, or dictionary of step method names keyed to
particular variable names (these are case insensitive). If no method is
provided for any particular variable then PyMC3 will automatically
decide upon a default, with the first option being the NUTS sampler.
The currently allowed methods are 'NUTS', 'HamiltonianMC',
'Metropolis', 'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and
particular variable names (these are case insensitive). If passing a
dictionary of methods, the values keyed on particular variables can be
lists of methods to form compound steps. If no method is provided for
any particular variable then PyMC3 will automatically decide upon a
default, with the first option being the NUTS sampler. The currently
allowed methods are 'NUTS', 'HamiltonianMC', 'Metropolis',
'BinaryMetropolis', 'BinaryGibbsMetropolis', 'Slice', and
'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step
method function itself here as it is outside of the model context
manager.
......@@ -393,13 +395,27 @@ class Pymc3(MCMCSampler):
if key not in self.__search_parameter_keys:
raise ValueError("Setting a step method for an unknown parameter '{}'".format(key))
else:
if self.step_method[key].lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
# check if using a compound step (a list of step
# methods for a particular parameter)
if isinstance(self.step_method[key], list):
sms = self.step_method[key]
else:
sms = [self.step_method[key]]
for sm in sms:
if sm.lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method[key]))
else:
self.step_method = self.step_method.lower()
# check if using a compound step (a list of step
# methods for a particular parameter)
if isinstance(self.step_method, list):
sms = self.step_method
else:
sms = [self.step_method]
if self.step_method not in step_methods:
raise ValueError("Using invalid step method '{}'".format(self.step_method))
for i in range(len(sms)):
if sms[i].lower() not in step_methods:
raise ValueError("Using invalid step method '{}'".format(sms[i]))
else:
self.step_method = None
......@@ -415,11 +431,23 @@ class Pymc3(MCMCSampler):
self.kwargs['step'] = []
with self.pymc3_model:
for key in self.step_method:
curmethod = self.step_method[key].lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
# check for a compound step list
if isinstance(self.step_method[key], list):
for sms in self.step_method[key]:
curmethod = sms.lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
else:
curmethod = self.step_method[key].lower()
self.kwargs['step'].append(pymc3.__dict__[step_methods[curmethod]]([self.pymc3_priors[key]]))
else:
with self.pymc3_model:
self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method]]()
# check for a compound step list
if isinstance(self.step_method, list):
compound = []
for sms in self.step_method:
compound.append(pymc3.__dict__[step_methods[sms.lower()]]())
else:
self.kwargs['step'] = None if self.step_method is None else pymc3.__dict__[step_methods[self.step_method.lower()]]()
# if a custom log_likelihood function requires a `sampler` argument
# then use that log_likelihood function, with the assumption that it
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment