From 38cd6d54204aae7a09c89414840810f07a34e4b2 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 19 Feb 2019 05:12:41 -0600
Subject: [PATCH] fix the emcee writing to file for the prerelease

---
 bilby/core/sampler/emcee.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index b52c2bee0..3117a6c7a 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -51,6 +51,11 @@ class Emcee(MCMCSampler):
                  use_ratio=False, plot=False, skip_import_verification=False,
                  pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
+        import emcee
+        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+            self.prerelease = True
+        else:
+            self.prerelease = False
         MCMCSampler.__init__(
             self, likelihood=likelihood, priors=priors, outdir=outdir,
             label=label, use_ratio=use_ratio, plot=plot,
@@ -82,7 +87,6 @@ class Emcee(MCMCSampler):
     @property
     def sampler_function_kwargs(self):
         import emcee
-
         keys = ['lnprob0', 'rstate0', 'blobs0', 'iterations', 'thin', 'storechain', 'mh_proposal']
 
         # updated function keywords for emcee > v2.2.1
@@ -93,7 +97,7 @@ class Emcee(MCMCSampler):
         function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
         function_kwargs['p0'] = self.pos0
 
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+        if self.prerelease:
             if function_kwargs['mh_proposal'] is not None:
                 logger.warning("The 'mh_proposal' option is no longer used "
                                "in emcee v{}, and will be ignored.".format(emcee.__version__))
@@ -109,8 +113,6 @@ class Emcee(MCMCSampler):
 
     @property
     def sampler_init_kwargs(self):
-        import emcee
-
         init_kwargs = {key: value
                        for key, value in self.kwargs.items()
                        if key not in self.sampler_function_kwargs}
@@ -122,7 +124,7 @@ class Emcee(MCMCSampler):
         updatekeys = {'dim': 'ndim',
                       'lnpostfn': 'log_prob_fn'}
 
-        if LooseVersion(emcee.__version__) > LooseVersion('2.2.1'):
+        if self.prerelease:
             for key in updatekeys:
                 if key in init_kwargs:
                     init_kwargs[updatekeys[key]] = init_kwargs.pop(key)
@@ -193,8 +195,10 @@ class Emcee(MCMCSampler):
 
         for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
                            total=self.nsteps):
-            points = np.hstack([sample[0], np.array(sample[3])])
-            # import IPython; IPython.embed()
+            if self.prerelease:
+                points = np.hstack([sample.coords, sample.blobs])
+            else:
+                points = np.hstack([sample[0], np.array(sample[3])])
             with open(out_file, "a") as ff:
                 for ii, point in enumerate(points):
                     ff.write(template.format(ii, *point))
-- 
GitLab