Skip to content
Snippets Groups Projects

Fix PyMC3 after deprecation of step_kwargs and nuts_kwargs

Files
2
from __future__ import absolute_import, print_function
from collections import OrderedDict
from distutils.version import StrictVersion
import numpy as np
@@ -45,8 +46,6 @@ class Pymc3(MCMCSampler):
'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step
method function itself here as it is outside of the model context
manager.
nuts_kwargs: dict
Keyword arguments for the NUTS sampler.
step_kwargs: dict
Options for steps methods other than NUTS. The dictionary is keyed on
lowercase step method names with values being dictionaries of keywords
@@ -56,13 +55,27 @@ class Pymc3(MCMCSampler):
default_kwargs = dict(
draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0,
chains=2, cores=1, tune=500, nuts_kwargs=None, step_kwargs=None, progressbar=True,
model=None, random_seed=None, discard_tuned_samples=True,
compute_convergence_checks=True)
chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None,
discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None,
step_kwargs=None,
)
default_nuts_kwargs = dict(
target_accept=None, max_treedepth=None, step_scale=None, Emax=None,
gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None,
scaling=None, is_cov=None, potential=None,
)
default_kwargs.update(default_nuts_kwargs)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False,
skip_import_verification=False, **kwargs):
# add default step kwargs
_, STEP_METHODS, _ = self._import_external_sampler()
self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS}
self.default_kwargs.update(self.default_step_kwargs)
super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label,
use_ratio=use_ratio, plot=plot,
skip_import_verification=skip_import_verification, **kwargs)
@@ -454,8 +467,35 @@ class Pymc3(MCMCSampler):
self.set_likelihood()
# get the step method keyword arguments
step_kwargs = self.kwargs.pop('step_kwargs')
nuts_kwargs = self.kwargs.pop('nuts_kwargs')
step_kwargs = self.kwargs.pop("step_kwargs")
if step_kwargs is not None:
# remove all individual default step kwargs if passed together using
# step_kwargs keywords
for key in self.default_step_kwargs:
self.kwargs.pop(key)
else:
# remove any None default step keywords and place others in step_kwargs
step_kwargs = {}
for key in self.default_step_kwargs:
if self.kwargs[key] is None:
self.kwargs.pop(key)
else:
step_kwargs[key] = self.kwargs.pop(key)
nuts_kwargs = self.kwargs.pop("nuts_kwargs")
if nuts_kwargs is not None:
# remove all individual default nuts kwargs if passed together using
# nuts_kwargs keywords
for key in self.default_nuts_kwargs:
self.kwargs.pop(key)
else:
# remove any None default nuts keywords and place others in nut_kwargs
nuts_kwargs = {}
for key in self.default_nuts_kwargs:
if self.kwargs[key] is None:
self.kwargs.pop(key)
else:
nuts_kwargs[key] = self.kwargs.pop(key)
methodslist = []
# set the step method
@@ -496,13 +536,19 @@ class Pymc3(MCMCSampler):
self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args)
else:
# re-add step_kwargs if no step methods are set
self.kwargs['step_kwargs'] = step_kwargs
if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
self.kwargs['step_kwargs'] = step_kwargs
# check whether only NUTS step method has been assigned
if np.all([sm.lower() == 'nuts' for sm in methodslist]):
# in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs
self.kwargs['step'] = None
self.kwargs['nuts_kwargs'] = nuts_kwargs
if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"):
self.kwargs['nuts_kwargs'] = nuts_kwargs
elif len(nuts_kwargs) > 0:
# add NUTS kwargs to standard kwargs
self.kwargs.update(nuts_kwargs)
with self.pymc3_model:
# perform the sampling
@@ -561,6 +607,10 @@ class Pymc3(MCMCSampler):
args = {}
return args, nuts_kwargs
def _pymc3_version(self):
pymc3, _, _ = self._import_external_sampler()
return pymc3.__version__
def set_prior(self):
"""
Set the PyMC3 prior distributions.
Loading