Skip to content
Snippets Groups Projects
Commit 02f5818b authored by Gregory Ashton's avatar Gregory Ashton
Browse files

More bug fixing

parent d807cb4e
No related merge requests found
......@@ -201,6 +201,11 @@ class Emcee(MCMCSampler):
""" Returns the log-prior stored on disk """
return self.stored_chain['log_p']
def _init_chain_file(self):
with open(self.checkpoint_info.chain_file, "w+") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
@property
def checkpoint_info(self):
""" Defines various things related to checkpointing and storing data
......@@ -219,14 +224,8 @@ class Emcee(MCMCSampler):
self.label))
check_directory_exists_and_if_not_mkdir(out_dir)
sampler_file = os.path.join(out_dir, 'sampler.pickle')
# Initialise chain file
chain_file = os.path.join(out_dir, 'chain.dat')
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
sampler_file = os.path.join(out_dir, 'sampler.pickle')
chain_template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
......@@ -262,6 +261,7 @@ class Emcee(MCMCSampler):
def _initialise_sampler(self):
import emcee
self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
self._init_chain_file()
@property
def sampler(self):
......@@ -288,7 +288,8 @@ class Emcee(MCMCSampler):
def write_chains_to_file(self, sample):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + '.temp'
copyfile(chain_file, temp_chain_file)
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
......
......@@ -72,6 +72,7 @@ class Ptemcee(Emcee):
self._sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
self._init_chain_file()
def print_tswap_acceptance_fraction(self):
logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
......@@ -80,8 +81,9 @@ class Ptemcee(Emcee):
def write_chains_to_file(self, pos, loglike, logpost):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + '.temp'
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
copyfile(chain_file, temp_chain_file)
with open(temp_chain_file, "a") as ff:
loglike = np.squeeze(loglike[0, :])
logprior = np.squeeze(logpost[0, :]) - loglike
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment