From 38cd6d54204aae7a09c89414840810f07a34e4b2 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Tue, 19 Feb 2019 05:12:41 -0600 Subject: [PATCH] fix the emcee writing to file for the prerelease --- bilby/core/sampler/emcee.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index b52c2bee0..3117a6c7a 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -51,6 +51,11 @@ class Emcee(MCMCSampler): use_ratio=False, plot=False, skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25, resume=True, burn_in_act=3, **kwargs): + import emcee + if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + self.prerelease = True + else: + self.prerelease = False MCMCSampler.__init__( self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, use_ratio=use_ratio, plot=plot, @@ -82,7 +87,6 @@ class Emcee(MCMCSampler): @property def sampler_function_kwargs(self): import emcee - keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal'] # updated function keywords for emcee > v2.2.1 @@ -93,7 +97,7 @@ class Emcee(MCMCSampler): function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} function_kwargs['p0'] = self.pos0 - if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + if self.prerelease: if function_kwargs['mh_proposal'] is not None: logger.warning("The 'mh_proposal' option is no longer used " "in emcee v{}, and will be ignored.".format(emcee.__version__)) @@ -109,8 +113,6 @@ class Emcee(MCMCSampler): @property def sampler_init_kwargs(self): - import emcee - init_kwargs = {key: value for key, value in self.kwargs.items() if key not in self.sampler_function_kwargs} @@ -122,7 +124,7 @@ class Emcee(MCMCSampler): updatekeys = {'dim': 'ndim', 'lnpostfn': 'log_prob_fn'} - if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'): + if self.prerelease: for key in updatekeys: if key in init_kwargs: init_kwargs[updatekeys[key]] = init_kwargs.pop(key) @@ -193,8 +195,10 @@ class Emcee(MCMCSampler): for sample in tqdm(sampler.sample(**self.sampler_function_kwargs), total=self.nsteps): - points = np.hstack([sample[0], np.array(sample[3])]) - # import IPython; IPython.embed() + if self.prerelease: + points = np.hstack([sample.coords, sample.blobs]) + else: + points = np.hstack([sample[0], np.array(sample[3])]) with open(out_file, "a") as ff: for ii, point in enumerate(points): ff.write(template.format(ii, *point)) -- GitLab