Skip to content
Snippets Groups Projects
Commit 9bffa289 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Overhaul to using pickles and signals

parent 57a23737
Branches release/2.0.x
No related tags found
1 merge request!423Improvements to checkpointing for emcee/ptemcee
from __future__ import absolute_import, print_function
from collections import namedtuple
import os
import signal
import sys
import numpy as np
from pandas import DataFrame
from distutils.version import LooseVersion
import dill as pickle
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
......@@ -66,6 +70,9 @@ class Emcee(MCMCSampler):
self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act
signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
signal.signal(signal.SIGINT, self.checkpoint_and_exit)
def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
......@@ -165,66 +172,139 @@ class Emcee(MCMCSampler):
def nsteps(self, nsteps):
self.kwargs['iterations'] = nsteps
def __getstate__(self):
# In order to be picklable with dill, we need to discard the pool
# object before trying.
d = self.__dict__
d["_Sampler__kwargs"]["pool"] = None
return d
@property
def stored_chain(self):
""" Read the stored zero-temperature chain data in from disk """
return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
def set_up_checkpoint(self):
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
out_file = os.path.join(out_dir, 'chain.dat')
@property
def stored_samples(self):
""" Returns the samples stored on disk """
return self.stored_chain[self.search_parameter_keys]
if self.resume:
self.load_old_chain(out_file)
else:
self._set_pos0()
@property
def stored_loglike(self):
""" Returns the log-likelihood stored on disk """
return self.stored_chain['log_l']
@property
def stored_logprior(self):
""" Returns the log-prior stored on disk """
return self.stored_chain['log_p']
@property
def checkpoint_info(self):
""" Defines various things related to checkpointing and storing data
Returns
-------
checkpoint_info: named_tuple
An object with attributes `sampler_file`, `chain_file`, and
`chain_template`. The first two give paths to where the sampler and
chain data is stored, the last a formatted-str-template with which
to write the chain data to disk
"""
out_dir = os.path.join(
self.outdir, '{}_{}'.format(self.__class__.__name__, self.label))
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(out_file):
with open(out_file, "w") as ff:
ff.write('walker\t{}\tlog_l\n'.format(
sampler_file = os.path.join(out_dir, 'sampler.pickle')
# Initialise chain file
chain_file = os.path.join(out_dir, 'chain.dat')
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
chain_template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
return out_file, template
CheckpointInfo = namedtuple(
'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
def run_sampler(self):
checkpoint_info = CheckpointInfo(
sampler_file=sampler_file, chain_file=chain_file,
chain_template=chain_template)
return checkpoint_info
@property
def sampler_chain(self):
nsteps = self._previous_iterations
return self.sampler.chain[:, :nsteps, :]
def checkpoint(self):
""" Writes a pickle file of the sampler to disk using dill """
logger.info("Checkpointing sampler to file {}"
.format(self.checkpoint_info.sampler_file))
with open(self.checkpoint_info.sampler_file, 'wb') as f:
# Overwrites the stored sampler chain with one that is truncated
# to only the completed steps
self.sampler._chain = self.sampler_chain
pickle.dump(self._sampler, f)
def checkpoint_and_exit(self, signum, frame):
logger.info("Recieved signal {}".format(signum))
self.checkpoint()
sys.exit()
def _initialise_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
out_file, template = self.set_up_checkpoint()
self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
def _set_pos0_for_resume(self):
self.pos0 = self.sampler.chain[:, -1, :]
@property
def sampler(self):
""" Returns the ptemcee sampler object
If, alrady initialized, returns the stored _sampler value. Otherwise,
first checks if there is a pickle file from which to load. If there is
not, then initialize the sampler and set the initial random draw
"""
if hasattr(self, '_sampler'):
pass
elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
with open(self.checkpoint_info.sampler_file, 'rb') as f:
self._sampler = pickle.load(f)
self._set_pos0_for_resume()
else:
self._initialise_sampler()
self._set_pos0()
return self._sampler
def write_chains_to_file(self, sample):
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(self.checkpoint_info.chain_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(self.checkpoint_info.chain_template.format(ii, *point))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
print('pos0', self.pos0)
sampler_function_kwargs['p0'] = self.pos0
for sample in tqdm(
sampler.sample(iterations=iterations, **sampler_function_kwargs),
self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations):
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(out_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(template.format(ii, *point))
self.write_chains_to_file(sample)
self.result.sampler_output = np.nan
blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
log_likelihoods, log_priors = blobs_flat.T
if self._old_chain is not None:
chain = np.vstack([self._old_chain[:, :-2],
sampler.chain.reshape((-1, self.ndim))])
log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
log_ps = np.hstack([self._old_chain[:, -1], log_priors])
self.nsteps = chain.shape[0] // self.nwalkers
else:
chain = sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
chain = self.sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
self.calculate_autocorrelation(chain)
self.print_nburn_logging_info()
self.result.nburn = self.nburn
......@@ -236,13 +316,27 @@ class Emcee(MCMCSampler):
self.result.samples = chain[n_samples:, :]
self.result.log_likelihood_evaluations = log_ls[n_samples:]
self.result.log_prior_evaluations = log_ps[n_samples:]
self.result.walkers = sampler.chain
self.result.walkers = self.sampler.chain
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
return self.result
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return len(self.sampler.blobs)
def _draw_pos0_from_prior(self):
return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
return np.array(
[self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
@property
def _pos0_shape(self):
return (self.nwalkers, self.ndim)
def _set_pos0(self):
if self.pos0 is not None:
......@@ -250,9 +344,9 @@ class Emcee(MCMCSampler):
if isinstance(self.pos0, DataFrame):
self.pos0 = self.pos0[self.search_parameter_keys].values
elif type(self.pos0) in (list, np.ndarray):
self.pos0 = np.squeeze(self.kwargs['pos0'])
self.pos0 = np.squeeze(self.pos0)
if self.pos0.shape != (self.nwalkers, self.ndim):
if self.pos0.shape != self._pos0_shape:
raise ValueError(
'Input pos0 should be of shape ndim, nwalkers')
logger.debug("Checking input pos0")
......@@ -262,51 +356,6 @@ class Emcee(MCMCSampler):
logger.debug("Generating initial walker positions from prior")
self.pos0 = self._draw_pos0_from_prior()
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx, :]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
@property
def _previous_iterations(self):
if self._old_chain is None:
return 0
try:
return self._old_chain.shape[0] // self.nwalkers
except AttributeError:
logger.warning(
"Unable to calculate previous iterations from checkpoint,"
" defaulting to zero")
return 0
def load_old_chain(self, file_name=None):
if file_name is None:
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
file_name = os.path.join(out_dir, 'chain.dat')
if os.path.isfile(file_name):
try:
old_chain = np.genfromtxt(file_name, skip_header=1)
self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
for ii in range(self.nwalkers)]
self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
except Exception:
logger.warning('Failed to resume. Corrupt checkpoint file {}.'
.format(file_name))
self._set_pos0()
else:
logger.warning('Failed to resume. {} not found.'.format(file_name))
self._set_pos0()
def lnpostfn(self, theta):
log_prior = self.log_prior(theta)
if np.isinf(log_prior):
......
from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from ..utils import logger, get_progress_bar
from . import Emcee
from .base_sampler import SamplerError
......@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee
"""
default_kwargs = dict(ntemps=2, nwalkers=500, Tmax=None, betas=None,
threads=1, pool=None, a=2.0, loglargs=[], logpargs=[],
loglkwargs={}, logpkwargs={}, adaptation_lag=10000,
adaptation_time=100, random=None, iterations=100,
thin=1, storechain=True, adapt=True, swap_ratios=False,
)
default_kwargs = dict(
ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
thin=1, storechain=True, adapt=True, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs}
@property
def checkpoint_info(self):
out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
chain_file = os.path.join(out_dir, 'chain.dat')
last_pos_file = os.path.join(out_dir, 'last_pos.npy')
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
CheckpointInfo = namedtuple(
'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
checkpoint_info = CheckpointInfo(
last_pos_file=last_pos_file, chain_file=chain_file, template=template)
return checkpoint_info
def ntemps(self):
return self.kwargs['ntemps']
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
def _set_pos0_for_resume(self):
self.pos0 = None
@property
def stored_chain(self):
return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
@property
def stored_samples(self):
return self.stored_chain[self.search_parameter_keys]
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
@property
def stored_loglike(self):
return self.stored_chain['log_l']
def sampler_chain(self):
nsteps = self._previous_iterations
return self.sampler.chain[:, :, :nsteps, :]
@property
def stored_logprior(self):
return self.stored_chain['log_p']
def load_old_chain(self):
try:
last_pos = np.load(self.checkpoint_info.last_pos_file)
self.pos0 = last_pos
self._old_chain = self.stored_samples
logger.info(
'Resuming from {} with {} iterations'.format(
self.checkpoint_info.chain_file,
self._previous_iterations))
except Exception:
logger.info('Unable to resume')
self._set_pos0()
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def run_sampler(self):
def _initialise_sampler(self):
import ptemcee
tqdm = get_progress_bar()
sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
logp=self.log_prior, **self.sampler_init_kwargs)
if self.resume:
self.load_old_chain()
else:
self._set_pos0()
self._sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
def print_tswap_acceptance_fraction(self):
logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
self.sampler.tswap_acceptance_fraction))
def write_chains_to_file(self, pos, loglike, logpost):
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[0, :])
logprior = np.squeeze(logpost[0, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.chain_template.format(ii, *line))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
# main iteration loop
for pos, logpost, loglike in tqdm(
sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs),
self.sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs),
total=iterations):
np.save(self.checkpoint_info.last_pos_file, pos)
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[:1, :])
logprior = np.squeeze(logpost[:1, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.template.format(ii, *line))
self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
self.write_chains_to_file(pos, loglike, logpost)
self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
self.result.sampler_output = np.nan
self.print_nburn_logging_info()
self.print_tswap_acceptance_fraction()
self.result.nburn = self.nburn
if self.result.nburn > self.nsteps:
raise SamplerError(
"The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.")
walkers = self.stored_samples.view((float, self.ndim))
walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim)
self.result.walkers = walkers
self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim))
self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.walkers = self.sampler.chain[0, :, :, :]
n_samples = self.nwalkers * self.nburn
self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
self.result.betas = sampler.betas
self.result.betas = self.sampler.betas
self.result.log_evidence, self.result.log_evidence_err =\
sampler.log_evidence_estimate(
sampler.loglikelihood, self.nburn / self.nsteps)
self.sampler.log_evidence_estimate(
self.sampler.loglikelihood, self.nburn / self.nsteps)
return self.result
......@@ -6,3 +6,4 @@ matplotlib>=2.0
scipy>=0.16
pandas
mock
dill
......@@ -79,6 +79,7 @@ setup(name='bilby',
'future',
'dynesty',
'corner',
'dill',
'numpy>=1.9',
'matplotlib>=2.0',
'pandas',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment