Skip to content
Snippets Groups Projects

Allow emcee to work with pre-release versions

Merged Matthew David Pitkin requested to merge matthew-pitkin/bilby:emcee_fix into master
@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import numpy as np
from pandas import DataFrame
from distutils.version import LooseVersion
from ..utils import logger, get_progress_bar
from .base_sampler import MCMCSampler, SamplerError
@@ -65,14 +66,58 @@ class Emcee(MCMCSampler):
@property
def sampler_function_kwargs(self):
import emcee
keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal']
return {key: self.kwargs[key] for key in keys}
# updated function keywords for emcee > v2.2.1
updatekeys = {'p0': 'initial_state',
'lnprob0': 'log_prob0',
'storechain': 'store'}
function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
function_kwargs['p0'] = self.pos0
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
if function_kwargs['mh_proposal'] is not None:
logger.warning("The 'mh_proposal' option is no longer used "
"in emcee v{}, and will be ignored.".format(emcee.__version__))
del function_kwargs['mh_proposal']
for key in updatekeys:
if updatekeys[key] not in function_kwargs:
function_kwargs[updatekeys[key]] = function_kwargs.pop(key)
else:
del function_kwargs[key]
return function_kwargs
@property
def sampler_init_kwargs(self):
return {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
import emcee
init_kwargs = {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
init_kwargs['lnpostfn'] = self.lnpostfn
init_kwargs['dim'] = self.ndim
# updated init keywords for emcee > v2.2.1
updatekeys = {'dim': 'ndim',
'lnpostfn': 'log_prob_fn'}
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
for key in updatekeys:
if key in init_kwargs:
init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal']
for key in oldfunckeys:
if key in init_kwargs:
del init_kwargs[key]
return init_kwargs
@property
def nburn(self):
@@ -107,9 +152,9 @@ class Emcee(MCMCSampler):
def run_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(dim=self.ndim, lnpostfn=self.lnpostfn, **self.sampler_init_kwargs)
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
self._set_pos0()
for _ in tqdm(sampler.sample(p0=self.pos0, **self.sampler_function_kwargs),
for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
total=self.nsteps):
pass
self.result.sampler_output = np.nan
Loading