Skip to content
Snippets Groups Projects
Commit 38cd6d54 authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

fix the emcee writing to file for the prerelease

parent 3e1a47e3
No related branches found
No related tags found
No related merge requests found
......@@ -51,6 +51,11 @@ class Emcee(MCMCSampler):
use_ratio=False, plot=False, skip_import_verification=False,
pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
burn_in_act=3, **kwargs):
import emcee
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
self.prerelease = True
else:
self.prerelease = False
MCMCSampler.__init__(
self, likelihood=likelihood, priors=priors, outdir=outdir,
label=label, use_ratio=use_ratio, plot=plot,
......@@ -82,7 +87,6 @@ class Emcee(MCMCSampler):
@property
def sampler_function_kwargs(self):
import emcee
keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal']
# updated function keywords for emcee > v2.2.1
......@@ -93,7 +97,7 @@ class Emcee(MCMCSampler):
function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
function_kwargs['p0'] = self.pos0
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
if self.prerelease:
if function_kwargs['mh_proposal'] is not None:
logger.warning("The 'mh_proposal' option is no longer used "
"in emcee v{}, and will be ignored.".format(emcee.__version__))
......@@ -109,8 +113,6 @@ class Emcee(MCMCSampler):
@property
def sampler_init_kwargs(self):
import emcee
init_kwargs = {key: value
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
......@@ -122,7 +124,7 @@ class Emcee(MCMCSampler):
updatekeys = {'dim': 'ndim',
'lnpostfn': 'log_prob_fn'}
if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
if self.prerelease:
for key in updatekeys:
if key in init_kwargs:
init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
......@@ -193,8 +195,10 @@ class Emcee(MCMCSampler):
for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
total=self.nsteps):
points = np.hstack([sample[0], np.array(sample[3])])
# import IPython; IPython.embed()
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(out_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(template.format(ii, *point))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment