From 02f5818b38805e10f04d57720e328eaa056e3b20 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 3 Apr 2019 16:26:12 +1100 Subject: [PATCH] More bug fixing --- bilby/core/sampler/emcee.py | 17 +++++++++-------- bilby/core/sampler/ptemcee.py | 4 +++- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 355e93ef1..6b4a0e732 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -201,6 +201,11 @@ class Emcee(MCMCSampler): """ Returns the log-prior stored on disk """ return self.stored_chain['log_p'] + def _init_chain_file(self): + with open(self.checkpoint_info.chain_file, "w+") as ff: + ff.write('walker\t{}\tlog_l\tlog_p\n'.format( + '\t'.join(self.search_parameter_keys))) + @property def checkpoint_info(self): """ Defines various things related to checkpointing and storing data @@ -219,14 +224,8 @@ class Emcee(MCMCSampler): self.label)) check_directory_exists_and_if_not_mkdir(out_dir) - sampler_file = os.path.join(out_dir, 'sampler.pickle') - - # Initialise chain file chain_file = os.path.join(out_dir, 'chain.dat') - if not os.path.isfile(chain_file): - with open(chain_file, "w") as ff: - ff.write('walker\t{}\tlog_l\tlog_p\n'.format( - '\t'.join(self.search_parameter_keys))) + sampler_file = os.path.join(out_dir, 'sampler.pickle') chain_template =\ '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' @@ -262,6 +261,7 @@ class Emcee(MCMCSampler): def _initialise_sampler(self): import emcee self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) + self._init_chain_file() @property def sampler(self): @@ -288,7 +288,8 @@ class Emcee(MCMCSampler): def write_chains_to_file(self, sample): chain_file = self.checkpoint_info.chain_file temp_chain_file = chain_file + '.temp' - copyfile(chain_file, temp_chain_file) + if os.path.isfile(chain_file): + copyfile(chain_file, temp_chain_file) if self.prerelease: points = np.hstack([sample.coords, sample.blobs]) diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index b6b073110..31c5609a6 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -72,6 +72,7 @@ class Ptemcee(Emcee): self._sampler = ptemcee.Sampler( dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior, **self.sampler_init_kwargs) + self._init_chain_file() def print_tswap_acceptance_fraction(self): logger.info("Sampler per-chain tswap acceptance fraction = {}".format( @@ -80,8 +81,9 @@ class Ptemcee(Emcee): def write_chains_to_file(self, pos, loglike, logpost): chain_file = self.checkpoint_info.chain_file temp_chain_file = chain_file + '.temp' + if os.path.isfile(chain_file): + copyfile(chain_file, temp_chain_file) - copyfile(chain_file, temp_chain_file) with open(temp_chain_file, "a") as ff: loglike = np.squeeze(loglike[0, :]) logprior = np.squeeze(logpost[0, :]) - loglike -- GitLab