Maintenance will be performed on git.ligo.org, chat.ligo.org, containers.ligo.org, and docs.ligo.org tomorrow, 25 February 2020, starting at 10am CST. There will be a short, around 5 minute, period of downtime towards the end of the maintenance window. In addition, the runners will be paused around 9am CST and resumed at the end of the maintenance.

Commit 9bffa289 authored by Gregory Ashton

Overhaul to using pickles and signals

parent 57a23737
This diff is collapsed.
from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from ..utils import logger, get_progress_bar
from . import Emcee
from .base_sampler import SamplerError
......@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee
"""
# Default keyword arguments forwarded to ptemcee's Sampler constructor and
# its sample() call (split between the two by sampler_init_kwargs /
# sampler_function_kwargs elsewhere in this class).
# NOTE(review): the diff contained two copies of this dict (old and new
# formatting); only the new single definition is kept here.
# NOTE(review): the list/dict values ([], {}) are shared class-level
# mutables — assumed never mutated in place; confirm against callers.
default_kwargs = dict(
    ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
    a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
    adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
    thin=1, storechain=True, adapt=True, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs}
@property
def checkpoint_info(self):
    """Paths and row template used for checkpointing this run.

    Side effects: ensures the checkpoint directory exists and, if the
    chain file is absent, writes its tab-separated header line.

    Returns
    -------
    CheckpointInfo: namedtuple with fields ``last_pos_file`` (walker
        positions, ``.npy``), ``chain_file`` (tab-separated chain data)
        and ``template`` (format string for one chain row).
    """
    directory = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
    check_directory_exists_and_if_not_mkdir(directory)

    chain_file = os.path.join(directory, 'chain.dat')
    last_pos_file = os.path.join(directory, 'last_pos.npy')

    if not os.path.isfile(chain_file):
        header = 'walker\t{}\tlog_l\tlog_p\n'.format(
            '\t'.join(self.search_parameter_keys))
        with open(chain_file, "w") as ff:
            ff.write(header)

    # One integer walker id, then each parameter plus log_l and log_p.
    row_template = (
        '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n')

    CheckpointInfo = namedtuple(
        'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
    return CheckpointInfo(
        last_pos_file=last_pos_file, chain_file=chain_file,
        template=row_template)
@property
def ntemps(self):
    """Number of parallel-tempering temperatures (``kwargs['ntemps']``).

    Restored as a property: it is read via attribute access elsewhere in
    this class (e.g. ``self.ntemps`` in ``_pos0_shape``), so the
    ``@property`` decorator dropped by the diff is reinstated.
    """
    return self.kwargs['ntemps']
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
def _set_pos0_for_resume(self):
self.pos0 = None
@property
def stored_chain(self):
    """The checkpointed chain read back from disk.

    Returns a numpy structured array whose field names come from the
    header line of the chain file (walker id, each search parameter,
    ``log_l`` and ``log_p``).
    """
    chain_path = self.checkpoint_info.chain_file
    return np.genfromtxt(chain_path, names=True)
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
@property
def stored_samples(self):
return self.stored_chain[self.search_parameter_keys]
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
@property
def stored_loglike(self):
    """Log-likelihood column of the checkpointed chain."""
    chain = self.stored_chain
    return chain['log_l']
@property
def sampler_chain(self):
    """The raw sampler chain, truncated to the iterations saved so far.

    NOTE(review): restored as a property to match the neighbouring
    accessors and the attribute-style read of ``_previous_iterations``;
    the decorator appears to have been lost in the diff — confirm.
    """
    nsteps = self._previous_iterations
    # chain axes: (ntemps, nwalkers, iterations, ndim)
    return self.sampler.chain[:, :, :nsteps, :]
@property
def stored_logprior(self):
    """Log-prior column of the checkpointed chain."""
    chain = self.stored_chain
    return chain['log_p']
def load_old_chain(self):
    """Attempt to resume from the checkpoint files on disk.

    On success, ``pos0`` is restored from the saved walker positions and
    ``_old_chain`` is repopulated from the stored samples. Any failure
    (missing or unreadable files) falls back to a fresh start via
    ``_set_pos0``; this is deliberately best-effort.
    """
    try:
        self.pos0 = np.load(self.checkpoint_info.last_pos_file)
        self._old_chain = self.stored_samples
        logger.info(
            'Resuming from {} with {} iterations'.format(
                self.checkpoint_info.chain_file,
                self._previous_iterations))
    except Exception:
        # Broad on purpose: resuming must never crash a run.
        logger.info('Unable to resume')
        self._set_pos0()
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def _initialise_sampler(self):
    """Construct the underlying ``ptemcee.Sampler`` instance.

    Stores the sampler on ``self._sampler`` using the class's
    ``sampler_init_kwargs``.

    NOTE(review): this span interleaved the removed ``run_sampler`` with
    the added ``_initialise_sampler``; the added version is reconstructed
    here. The resume/``pos0`` handling visible in the removed lines is
    presumably driven from ``run_sampler``/the base class — confirm
    against the full post-commit file.
    """
    import ptemcee
    self._sampler = ptemcee.Sampler(
        dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
        **self.sampler_init_kwargs)
def print_tswap_acceptance_fraction(self):
    """Log the per-chain temperature-swap acceptance fraction."""
    fraction = self.sampler.tswap_acceptance_fraction
    logger.info(
        "Sampler per-chain tswap acceptance fraction = {}".format(fraction))
def write_chains_to_file(self, pos, loglike, logpost):
    """Append the current zero-temperature state to the chain file.

    Only the coldest chain (index 0 on the temperature axis) is stored;
    one line per walker: walker id, parameters, log_l, log_p.

    Parameters
    ----------
    pos: array_like
        Walker positions; assumed (ntemps, nwalkers, ndim) — TODO confirm
    loglike: array_like
        Log-likelihoods; assumed (ntemps, nwalkers)
    logpost: array_like
        Log-posteriors; assumed (ntemps, nwalkers)
    """
    cold_loglike = np.squeeze(loglike[0, :])
    # log-prior = log-posterior - log-likelihood
    cold_logprior = np.squeeze(logpost[0, :]) - cold_loglike
    with open(self.checkpoint_info.chain_file, "a") as ff:
        for ii, (point, logl, logp) in enumerate(
                zip(pos[0, :, :], cold_loglike, cold_logprior)):
            line = np.concatenate((point, [logl, logp]))
            # Fix: the CheckpointInfo namedtuple field is ``template``
            # (see ``checkpoint_info``), not ``chain_template``.
            ff.write(self.checkpoint_info.template.format(ii, *line))
def run_sampler(self):
    """Run the sampler, checkpointing to disk after every iteration.

    Populates ``self.result`` (samples, walkers, log-evidence, etc.) and
    returns it.

    Raises
    ------
    SamplerError
        If the estimated burn-in exceeds the number of steps actually run.

    NOTE(review): this span interleaved removed and added diff lines
    (duplicate checkpoint-loop bodies, ``sampler`` vs ``self.sampler``,
    duplicate result assignments); the added (new) version is
    reconstructed here — confirm against the full post-commit file.
    """
    tqdm = get_progress_bar()
    sampler_function_kwargs = self.sampler_function_kwargs
    iterations = sampler_function_kwargs.pop('iterations')
    # When resuming, only run the iterations that remain.
    iterations -= self._previous_iterations

    # Main iteration loop: checkpoint the walker positions and the cold
    # chain after every step so the run can be resumed.
    for pos, logpost, loglike in tqdm(
            self.sampler.sample(self.pos0, iterations=iterations,
                                **sampler_function_kwargs),
            total=iterations):
        np.save(self.checkpoint_info.last_pos_file, pos)
        self.write_chains_to_file(pos, loglike, logpost)

    self.calculate_autocorrelation(
        self.sampler.chain.reshape((-1, self.ndim)))
    self.result.sampler_output = np.nan
    self.print_nburn_logging_info()
    self.print_tswap_acceptance_fraction()
    self.result.nburn = self.nburn
    if self.result.nburn > self.nsteps:
        raise SamplerError(
            "The run has finished, but the chain is not burned in: "
            "`nburn < nsteps`. Try increasing the number of steps.")
    # Keep only the zero-temperature chain, discarding the burn-in.
    self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
        (-1, self.ndim))
    self.result.walkers = self.sampler.chain[0, :, :, :]
    n_samples = self.nwalkers * self.nburn
    self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
    self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
    self.result.betas = self.sampler.betas
    self.result.log_evidence, self.result.log_evidence_err = \
        self.sampler.log_evidence_estimate(
            self.sampler.loglikelihood, self.nburn / self.nsteps)
    return self.result
......@@ -6,3 +6,4 @@ matplotlib>=2.0
scipy>=0.16
pandas
mock
dill
......@@ -79,6 +79,7 @@ setup(name='bilby',
'future',
'dynesty',
'corner',
'dill',
'numpy>=1.9',
'matplotlib>=2.0',
'pandas',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment