Maintenance will be performed on git.ligo.org, chat.ligo.org, and docs.ligo.org, starting at approximately 10am CDT Tuesday 20 August 2019. The maintenance is expected to take around an hour and here will be two short periods of downtime, one at the beginning of the maintenance and another at the end.

Commit 02f5818b authored by Gregory Ashton's avatar Gregory Ashton

More bug fixing

parent d807cb4e
Pipeline #56033 passed with stage
in 7 minutes and 42 seconds
......@@ -201,6 +201,11 @@ class Emcee(MCMCSampler):
""" Returns the log-prior stored on disk """
return self.stored_chain['log_p']
def _init_chain_file(self):
with open(self.checkpoint_info.chain_file, "w+") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
@property
def checkpoint_info(self):
""" Defines various things related to checkpointing and storing data
......@@ -219,14 +224,8 @@ class Emcee(MCMCSampler):
self.label))
check_directory_exists_and_if_not_mkdir(out_dir)
sampler_file = os.path.join(out_dir, 'sampler.pickle')
# Initialise chain file
chain_file = os.path.join(out_dir, 'chain.dat')
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
sampler_file = os.path.join(out_dir, 'sampler.pickle')
chain_template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
......@@ -262,6 +261,7 @@ class Emcee(MCMCSampler):
def _initialise_sampler(self):
import emcee
self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
self._init_chain_file()
@property
def sampler(self):
......@@ -288,7 +288,8 @@ class Emcee(MCMCSampler):
def write_chains_to_file(self, sample):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + '.temp'
copyfile(chain_file, temp_chain_file)
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
......
......@@ -72,6 +72,7 @@ class Ptemcee(Emcee):
self._sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
self._init_chain_file()
def print_tswap_acceptance_fraction(self):
logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
......@@ -80,8 +81,9 @@ class Ptemcee(Emcee):
def write_chains_to_file(self, pos, loglike, logpost):
chain_file = self.checkpoint_info.chain_file
temp_chain_file = chain_file + '.temp'
if os.path.isfile(chain_file):
copyfile(chain_file, temp_chain_file)
copyfile(chain_file, temp_chain_file)
with open(temp_chain_file, "a") as ff:
loglike = np.squeeze(loglike[0, :])
logprior = np.squeeze(logpost[0, :]) - loglike
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment