diff --git a/bilby/core/sampler/pymc3.py b/bilby/core/sampler/pymc3.py index 824f294ce47349582aa3ec702cc7ff0ff1639781..73f5f3edb13aa9715a67b11b0a3a945667dbfefe 100644 --- a/bilby/core/sampler/pymc3.py +++ b/bilby/core/sampler/pymc3.py @@ -1,6 +1,7 @@ from __future__ import absolute_import, print_function from collections import OrderedDict +from distutils.version import StrictVersion import numpy as np @@ -45,8 +46,6 @@ class Pymc3(MCMCSampler): 'CategoricalGibbsMetropolis'. Note: you cannot provide a PyMC3 step method function itself here as it is outside of the model context manager. - nuts_kwargs: dict - Keyword arguments for the NUTS sampler. step_kwargs: dict Options for steps methods other than NUTS. The dictionary is keyed on lowercase step method names with values being dictionaries of keywords @@ -56,13 +55,27 @@ class Pymc3(MCMCSampler): default_kwargs = dict( draws=500, step=None, init='auto', n_init=200000, start=None, trace=None, chain_idx=0, - chains=2, cores=1, tune=500, nuts_kwargs=None, step_kwargs=None, progressbar=True, - model=None, random_seed=None, discard_tuned_samples=True, - compute_convergence_checks=True) + chains=2, cores=1, tune=500, progressbar=True, model=None, random_seed=None, + discard_tuned_samples=True, compute_convergence_checks=True, nuts_kwargs=None, + step_kwargs=None, + ) + + default_nuts_kwargs = dict( + target_accept=None, max_treedepth=None, step_scale=None, Emax=None, + gamma=None, k=None, t0=None, adapt_step_size=None, early_max_treedepth=None, + scaling=None, is_cov=None, potential=None, + ) + + default_kwargs.update(default_nuts_kwargs) def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, **kwargs): + # add default step kwargs + _, STEP_METHODS, _ = self._import_external_sampler() + self.default_step_kwargs = {m.__name__.lower(): None for m in STEP_METHODS} + self.default_kwargs.update(self.default_step_kwargs) + super(Pymc3, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label, use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, **kwargs) @@ -454,8 +467,35 @@ class Pymc3(MCMCSampler): self.set_likelihood() # get the step method keyword arguments - step_kwargs = self.kwargs.pop('step_kwargs') - nuts_kwargs = self.kwargs.pop('nuts_kwargs') + step_kwargs = self.kwargs.pop("step_kwargs") + if step_kwargs is not None: + # remove all individual default step kwargs if passed together using + # step_kwargs keywords + for key in self.default_step_kwargs: + self.kwargs.pop(key) + else: + # remove any None default step keywords and place others in step_kwargs + step_kwargs = {} + for key in self.default_step_kwargs: + if self.kwargs[key] is None: + self.kwargs.pop(key) + else: + step_kwargs[key] = self.kwargs.pop(key) + + nuts_kwargs = self.kwargs.pop("nuts_kwargs") + if nuts_kwargs is not None: + # remove all individual default nuts kwargs if passed together using + # nuts_kwargs keywords + for key in self.default_nuts_kwargs: + self.kwargs.pop(key) + else: + # remove any None default nuts keywords and place others in nut_kwargs + nuts_kwargs = {} + for key in self.default_nuts_kwargs: + if self.kwargs[key] is None: + self.kwargs.pop(key) + else: + nuts_kwargs[key] = self.kwargs.pop(key) methodslist = [] # set the step method @@ -496,13 +536,19 @@ class Pymc3(MCMCSampler): self.kwargs['step'] = pymc3.__dict__[step_methods[curmethod]](**args) else: # re-add step_kwargs if no step methods are set - self.kwargs['step_kwargs'] = step_kwargs + if len(step_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"): + self.kwargs['step_kwargs'] = step_kwargs # check whether only NUTS step method has been assigned if np.all([sm.lower() == 'nuts' for sm in methodslist]): # in this case we can let PyMC3 autoinitialise NUTS, so remove the step methods and re-add nuts_kwargs self.kwargs['step'] = None - self.kwargs['nuts_kwargs'] = nuts_kwargs + + if len(nuts_kwargs) > 0 and StrictVersion(pymc3.__version__) < StrictVersion("3.7"): + self.kwargs['nuts_kwargs'] = nuts_kwargs + elif len(nuts_kwargs) > 0: + # add NUTS kwargs to standard kwargs + self.kwargs.update(nuts_kwargs) with self.pymc3_model: # perform the sampling @@ -561,6 +607,10 @@ class Pymc3(MCMCSampler): args = {} return args, nuts_kwargs + def _pymc3_version(self): + pymc3, _, _ = self._import_external_sampler() + return pymc3.__version__ + def set_prior(self): """ Set the PyMC3 prior distributions. diff --git a/test/sampler_test.py b/test/sampler_test.py index c3c5e4d415e16a2d9f662b47be5b9d1c20bc8538..f03fc97d00265993dd1375825139e213add14ad7 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -695,14 +695,16 @@ class TestPyMC3(unittest.TestCase): chains=2, cores=1, tune=500, - nuts_kwargs=None, - step_kwargs=None, progressbar=True, model=None, + nuts_kwargs=None, + step_kwargs=None, random_seed=None, discard_tuned_samples=True, compute_convergence_checks=True, ) + expected.update(self.sampler.default_nuts_kwargs) + expected.update(self.sampler.default_step_kwargs) self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): @@ -717,14 +719,16 @@ class TestPyMC3(unittest.TestCase): chains=2, cores=1, tune=500, - nuts_kwargs=None, - step_kwargs=None, progressbar=True, model=None, + nuts_kwargs=None, + step_kwargs=None, random_seed=None, discard_tuned_samples=True, compute_convergence_checks=True, ) + expected.update(self.sampler.default_nuts_kwargs) + expected.update(self.sampler.default_step_kwargs) self.sampler.kwargs["draws"] = 123 for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy()