diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index b74314e000e248b28eee3f69f2372fc648c45359..56e6887d188ccb78e16e614799f74cd324a7e8b9 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function import numpy as np from pandas import DataFrame +from distutils.version import LooseVersion from ..utils import logger, get_progress_bar from .base_sampler import MCMCSampler, SamplerError @@ -73,14 +74,58 @@ class Emcee(MCMCSampler): @property def sampler_function_kwargs(self): + import emcee + keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal'] - return {key: self.kwargs[key] for key in keys} + + # updated function keywords for emcee > v2.2.1 + updatekeys = {'p0': 'initial_state', + 'lnprob0': 'log_prob0', + 'storechain': 'store'} + + function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} + function_kwargs['p0'] = self.pos0 + + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + if function_kwargs['mh_proposal'] is not None: + logger.warning("The 'mh_proposal' option is no longer used " + "in emcee v{}, and will be ignored.".format(emcee.__version__)) + del function_kwargs['mh_proposal'] + + for key in updatekeys: + if updatekeys[key] not in function_kwargs: + function_kwargs[updatekeys[key]] = function_kwargs.pop(key) + else: + del function_kwargs[key] + + return function_kwargs @property def sampler_init_kwargs(self): - return {key: value - for key, value in self.kwargs.items() - if key not in self.sampler_function_kwargs} + import emcee + + init_kwargs = {key: value + for key, value in self.kwargs.items() + if key not in self.sampler_function_kwargs} + + init_kwargs['lnpostfn'] = self.lnpostfn + init_kwargs['dim'] = self.ndim + + # updated init keywords for emcee > v2.2.1 + updatekeys = {'dim': 'ndim', + 'lnpostfn': 'log_prob_fn'} + + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + for key in updatekeys: + if key in init_kwargs: + init_kwargs[updatekeys[key]] = init_kwargs.pop(key) + + oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal'] + for key in oldfunckeys: + if key in init_kwargs: + del init_kwargs[key] + + return init_kwargs @property def nburn(self): @@ -122,9 +167,9 @@ class Emcee(MCMCSampler): def run_sampler(self): import emcee tqdm = get_progress_bar() - sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs) + sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) self._set_pos0() - for _ in tqdm(sampler.sample(p0=self.pos0, **self.sampler_function_kwargs), + for _ in tqdm(sampler.sample(**self.sampler_function_kwargs), total=self.nsteps): pass self.result.sampler_output = np.nan