From 02f5818b38805e10f04d57720e328eaa056e3b20 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 3 Apr 2019 16:26:12 +1100
Subject: [PATCH] More bug fixing

---
 bilby/core/sampler/emcee.py   | 17 +++++++++--------
 bilby/core/sampler/ptemcee.py |  4 +++-
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 355e93ef1..6b4a0e732 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -201,6 +201,11 @@ class Emcee(MCMCSampler):
         """ Returns the log-prior stored on disk """
         return self.stored_chain['log_p']
 
+    def _init_chain_file(self):
+        with open(self.checkpoint_info.chain_file, "w+") as ff:
+            ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
+                '\t'.join(self.search_parameter_keys)))
+
     @property
     def checkpoint_info(self):
         """ Defines various things related to checkpointing and storing data
@@ -219,14 +224,8 @@ class Emcee(MCMCSampler):
                                         self.label))
         check_directory_exists_and_if_not_mkdir(out_dir)
 
-        sampler_file = os.path.join(out_dir, 'sampler.pickle')
-
-        # Initialise chain file
         chain_file = os.path.join(out_dir, 'chain.dat')
-        if not os.path.isfile(chain_file):
-            with open(chain_file, "w") as ff:
-                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
-                    '\t'.join(self.search_parameter_keys)))
+        sampler_file = os.path.join(out_dir, 'sampler.pickle')
         chain_template =\
             '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
 
@@ -262,6 +261,7 @@ class Emcee(MCMCSampler):
     def _initialise_sampler(self):
         import emcee
         self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
+        self._init_chain_file()
 
     @property
     def sampler(self):
@@ -288,7 +288,8 @@ class Emcee(MCMCSampler):
     def write_chains_to_file(self, sample):
         chain_file = self.checkpoint_info.chain_file
         temp_chain_file = chain_file + '.temp'
-        copyfile(chain_file, temp_chain_file)
+        if os.path.isfile(chain_file):
+            copyfile(chain_file, temp_chain_file)
 
         if self.prerelease:
             points = np.hstack([sample.coords, sample.blobs])
diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py
index b6b073110..31c5609a6 100644
--- a/bilby/core/sampler/ptemcee.py
+++ b/bilby/core/sampler/ptemcee.py
@@ -72,6 +72,7 @@ class Ptemcee(Emcee):
         self._sampler = ptemcee.Sampler(
             dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
             **self.sampler_init_kwargs)
+        self._init_chain_file()
 
     def print_tswap_acceptance_fraction(self):
         logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
@@ -80,8 +81,9 @@ class Ptemcee(Emcee):
     def write_chains_to_file(self, pos, loglike, logpost):
         chain_file = self.checkpoint_info.chain_file
         temp_chain_file = chain_file + '.temp'
+        if os.path.isfile(chain_file):
+            copyfile(chain_file, temp_chain_file)
 
-        copyfile(chain_file, temp_chain_file)
         with open(temp_chain_file, "a") as ff:
             loglike = np.squeeze(loglike[0, :])
             logprior = np.squeeze(logpost[0, :]) - loglike
-- 
GitLab