Skip to content
Snippets Groups Projects
Commit c4eb8db7 authored by Matthew Pitkin's avatar Matthew Pitkin
Browse files

emcee.py: some re-factoring to try and get the tests to work

parent 087e0147
No related branches found
No related tags found
1 merge request!340Allow emcee to work with pre-release versions
......@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import numpy as np
from pandas import DataFrame
from distutils.version import LooseVersion
from ..utils import logger, get_progress_bar
from .base_sampler import MCMCSampler, SamplerError
......@@ -51,23 +52,9 @@ class Emcee(MCMCSampler):
**kwargs)
self.pos0 = pos0
self.nburn = nburn
self.nwalkers = self.kwargs.pop('nwalkers')
self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act
# a fix for versions of emcee newer than 2.2.1
if 'lnprob0' in self.kwargs:
from distutils.version import LooseVersion
import emcee
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
self.kwargs['log_prob0'] = self.kwargs.pop('lnprob0')
self.kwargs['store'] = self.kwargs.pop('storechain')
if self.kwargs['mh_proposal'] is not None:
logger.warning("The 'mh_proposal' option is no longer used "
"in emcee v{}, and will be ignored.".format(emcee.__version__))
del self.kwargs['mh_proposal']
def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
......@@ -79,14 +66,56 @@ class Emcee(MCMCSampler):
@property
def sampler_function_kwargs(self):
keys = ['lnprob0', 'log_prob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'store', 'mh_proposal']
return {key: self.kwargs[key] for key in keys if key in self.kwargs}
import emcee
keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal']
# updated function keywords for emcee > v2.2.1
updatekeys = {'p0': 'initial_state',
'lnprob0': 'log_prob0',
'storechain': 'store'}
function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
function_kwargs['p0'] = self.pos0
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
for key in updatekeys:
function_kwargs[updatekeys[key]] = function_kwargs.pop(key)
if function_kwargs['mh_proposal'] is not None:
logger.warning("The 'mh_proposal' option is no longer used "
"in emcee v{}, and will be ignored.".format(emcee.__version__))
del function_kwargs['mh_proposal']
return function_kwargs
@property
def sampler_init_kwargs(self):
return {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
import emcee
init_kwargs = {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
init_kwargs['lnpostfn'] = self.lnpostfn
init_kwargs['dim'] = self.ndim
# updated init keywords for emcee > v2.2.1
updatekeys = {'dim': 'ndim',
'lnpostfn': 'log_prob_fn'}
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
for key in updatekeys:
if key in init_kwargs:
init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
oldfunckeys = ['p0', 'lnprob0', 'storechain', 'mh_proposal']
for key in oldfunckeys:
if key in init_kwargs:
del init_kwargs[key]
return init_kwargs
@property
def nburn(self):
......@@ -108,12 +137,8 @@ class Emcee(MCMCSampler):
@property
def nwalkers(self):
return self.__nwalkers
@nwalkers.setter
def nwalkers(self, nwalkers):
self.__nwalkers = nwalkers
return self.kwargs['nwalkers']
@property
def nsteps(self):
return self.kwargs['iterations']
......@@ -125,9 +150,9 @@ class Emcee(MCMCSampler):
def run_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(self.nwalkers, self.ndim, self.lnpostfn, **self.sampler_init_kwargs)
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
self._set_pos0()
for _ in tqdm(sampler.sample(self.pos0, **self.sampler_function_kwargs),
for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
total=self.nsteps):
pass
self.result.sampler_output = np.nan
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment