Skip to content
Snippets Groups Projects
Commit 9bffa289 authored by Gregory Ashton
Browse files

Overhaul to using pickles and signals

parent 57a23737
No related branches found
No related tags found
1 merge request: !423 Improvements to checkpointing for emcee/ptemcee
from __future__ import absolute_import, print_function
from collections import namedtuple
import os
import signal
import sys
import numpy as np
from pandas import DataFrame
from distutils.version import LooseVersion
import dill as pickle
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
......@@ -66,6 +70,9 @@ class Emcee(MCMCSampler):
self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act
signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
signal.signal(signal.SIGINT, self.checkpoint_and_exit)
def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs:
for equiv in self.nwalkers_equiv_kwargs:
......@@ -165,66 +172,139 @@ class Emcee(MCMCSampler):
def nsteps(self, nsteps):
self.kwargs['iterations'] = nsteps
def __getstate__(self):
    """Return a picklable copy of the instance state with the pool removed.

    In order to be picklable with dill, the pool object has to be
    discarded. The returned state is built from copies of the instance
    dictionary and of the stored kwargs, so the ``pool`` entry is only
    cleared in the pickled state and not on the live object.

    Returns
    -------
    dict
        Shallow copy of ``self.__dict__`` in which the ``"pool"`` entry of
        the stored kwargs is ``None``.
    """
    state = self.__dict__.copy()
    # "_Sampler__kwargs" looks like the name-mangled ``__kwargs`` attribute
    # of a ``Sampler`` base class -- TODO confirm against the base class.
    kwargs = state["_Sampler__kwargs"].copy()
    kwargs["pool"] = None
    state["_Sampler__kwargs"] = kwargs
    return state
@property
def stored_chain(self):
    """Load the chain data that has been written to disk.

    Returns
    -------
    numpy.ndarray
        Structured array read from ``self.checkpoint_info.chain_file`` with
        ``np.genfromtxt(..., names=True)``, so the field names are taken
        from the header line of the file.
    """
    chain_path = self.checkpoint_info.chain_file
    return np.genfromtxt(chain_path, names=True)
def set_up_checkpoint(self):
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
out_file = os.path.join(out_dir, 'chain.dat')
@property
def stored_samples(self):
""" Returns the samples stored on disk """
return self.stored_chain[self.search_parameter_keys]
if self.resume:
self.load_old_chain(out_file)
else:
self._set_pos0()
@property
def stored_loglike(self):
    """The ``'log_l'`` (log-likelihood) column of the chain stored on disk."""
    chain = self.stored_chain
    return chain['log_l']
@property
def stored_logprior(self):
    """The ``'log_p'`` (log-prior) column of the chain stored on disk."""
    chain = self.stored_chain
    return chain['log_p']
@property
def checkpoint_info(self):
""" Defines various things related to checkpointing and storing data
Returns
-------
checkpoint_info: named_tuple
An object with attributes `sampler_file`, `chain_file`, and
`chain_template`. The first two give paths to where the sampler and
chain data is stored, the last a formatted-str-template with which
to write the chain data to disk
"""
out_dir = os.path.join(
self.outdir, '{}_{}'.format(self.__class__.__name__, self.label))
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(out_file):
with open(out_file, "w") as ff:
ff.write('walker\t{}\tlog_l\n'.format(
sampler_file = os.path.join(out_dir, 'sampler.pickle')
# Initialise chain file
chain_file = os.path.join(out_dir, 'chain.dat')
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
chain_template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
return out_file, template
CheckpointInfo = namedtuple(
'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
def run_sampler(self):
checkpoint_info = CheckpointInfo(
sampler_file=sampler_file, chain_file=chain_file,
chain_template=chain_template)
return checkpoint_info
@property
def sampler_chain(self):
    """Sampler chain restricted to the steps completed so far.

    The second axis of ``self.sampler.chain`` is cut down to its first
    ``self._previous_iterations`` entries; the other axes are kept whole.
    """
    completed = slice(None, self._previous_iterations)
    return self.sampler.chain[:, completed, :]
def checkpoint(self):
    """Pickle the sampler to ``self.checkpoint_info.sampler_file`` using dill.

    Before the dump, the chain stored on the sampler is replaced by
    ``self.sampler_chain``, i.e. the chain truncated to the completed steps.
    """
    message = "Checkpointing sampler to file {}".format(
        self.checkpoint_info.sampler_file)
    logger.info(message)
    with open(self.checkpoint_info.sampler_file, 'wb') as handle:
        # Only completed steps should end up in the pickle, so overwrite the
        # sampler's stored chain with the truncated one first.
        truncated_chain = self.sampler_chain
        self.sampler._chain = truncated_chain
        pickle.dump(self._sampler, handle)
def checkpoint_and_exit(self, signum, frame):
    """Signal handler: log the signal, checkpoint the sampler, then exit.

    Parameters
    ----------
    signum: int
        Number of the received signal; only used in the log message.
    frame:
        Stack frame handed over by the ``signal`` module; not used.

    Raises
    ------
    SystemExit
        Always, through ``sys.exit()``, once the checkpoint has been written.
    """
    logger.info("Received signal {}".format(signum))
    self.checkpoint()
    sys.exit()
def _initialise_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
out_file, template = self.set_up_checkpoint()
self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
def _set_pos0_for_resume(self):
    """Set ``pos0`` to the last step along the second axis of the chain."""
    last_step = self.sampler.chain[:, -1, :]
    self.pos0 = last_step
@property
def sampler(self):
    """ Returns the sampler object, loading or creating it on first access

    If already initialized, returns the stored ``_sampler`` value.
    Otherwise, when resuming and a sampler pickle file exists, the sampler
    is loaded from that file and ``pos0`` is set via
    ``_set_pos0_for_resume``. If there is nothing to load, the sampler is
    initialised and ``pos0`` is set via ``_set_pos0``.
    """
    if not hasattr(self, '_sampler'):
        can_resume = (
            self.resume
            and os.path.isfile(self.checkpoint_info.sampler_file))
        if can_resume:
            with open(self.checkpoint_info.sampler_file, 'rb') as handle:
                self._sampler = pickle.load(handle)
            self._set_pos0_for_resume()
        else:
            self._initialise_sampler()
            self._set_pos0()
    return self._sampler
def write_chains_to_file(self, sample):
    """Append the positions and blobs of one sampler step to the chain file.

    Parameters
    ----------
    sample:
        One item yielded by the sampler. When ``self.prerelease`` is true it
        is read through ``sample.coords`` and ``sample.blobs``; otherwise
        through ``sample[0]`` and ``sample[3]``.

    One line per row of the stacked array is written, formatted with
    ``self.checkpoint_info.chain_template`` and prefixed by the row index.
    """
    if self.prerelease:
        coords, blobs = sample.coords, sample.blobs
    else:
        coords, blobs = sample[0], np.array(sample[3])
    rows = np.hstack([coords, blobs])
    with open(self.checkpoint_info.chain_file, "a") as chain_handle:
        for walker, row in enumerate(rows):
            line = self.checkpoint_info.chain_template.format(walker, *row)
            chain_handle.write(line)
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
print('pos0', self.pos0)
sampler_function_kwargs['p0'] = self.pos0
for sample in tqdm(
sampler.sample(iterations=iterations, **sampler_function_kwargs),
self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations):
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(out_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(template.format(ii, *point))
self.write_chains_to_file(sample)
self.result.sampler_output = np.nan
blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
log_likelihoods, log_priors = blobs_flat.T
if self._old_chain is not None:
chain = np.vstack([self._old_chain[:, :-2],
sampler.chain.reshape((-1, self.ndim))])
log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
log_ps = np.hstack([self._old_chain[:, -1], log_priors])
self.nsteps = chain.shape[0] // self.nwalkers
else:
chain = sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
chain = self.sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
self.calculate_autocorrelation(chain)
self.print_nburn_logging_info()
self.result.nburn = self.nburn
......@@ -236,13 +316,27 @@ class Emcee(MCMCSampler):
self.result.samples = chain[n_samples:, :]
self.result.log_likelihood_evaluations = log_ls[n_samples:]
self.result.log_prior_evaluations = log_ps[n_samples:]
self.result.walkers = sampler.chain
self.result.walkers = self.sampler.chain
self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan
return self.result
@property
def _previous_iterations(self):
    """ Number of iterations that the sampler has saved

    Taken as the length of ``self.sampler.blobs``. This is used when a
    sampler has been loaded from a pickle file to work out how much of the
    run has already been completed.
    """
    stored_blobs = self.sampler.blobs
    return len(stored_blobs)
def _draw_pos0_from_prior(self):
    """Draw one starting position per walker from the prior.

    Returns
    -------
    numpy.ndarray
        Array built from ``self.nwalkers`` calls to
        ``self.get_random_draw_from_prior()``, one row per walker.
    """
    # A single return of an array: callers can then rely on ``.shape``
    # instead of receiving a plain list of draws.
    draws = [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
    return np.array(draws)
@property
def _pos0_shape(self):
    """Expected shape of ``pos0``: the tuple ``(nwalkers, ndim)``."""
    n_walkers, n_dim = self.nwalkers, self.ndim
    return (n_walkers, n_dim)
def _set_pos0(self):
if self.pos0 is not None:
......@@ -250,9 +344,9 @@ class Emcee(MCMCSampler):
if isinstance(self.pos0, DataFrame):
self.pos0 = self.pos0[self.search_parameter_keys].values
elif type(self.pos0) in (list, np.ndarray):
self.pos0 = np.squeeze(self.kwargs['pos0'])
self.pos0 = np.squeeze(self.pos0)
if self.pos0.shape != (self.nwalkers, self.ndim):
if self.pos0.shape != self._pos0_shape:
raise ValueError(
'Input pos0 should be of shape ndim, nwalkers')
logger.debug("Checking input pos0")
......@@ -262,51 +356,6 @@ class Emcee(MCMCSampler):
logger.debug("Generating initial walker positions from prior")
self.pos0 = self._draw_pos0_from_prior()
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx, :]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
@property
def _previous_iterations(self):
if self._old_chain is None:
return 0
try:
return self._old_chain.shape[0] // self.nwalkers
except AttributeError:
logger.warning(
"Unable to calculate previous iterations from checkpoint,"
" defaulting to zero")
return 0
def load_old_chain(self, file_name=None):
if file_name is None:
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
file_name = os.path.join(out_dir, 'chain.dat')
if os.path.isfile(file_name):
try:
old_chain = np.genfromtxt(file_name, skip_header=1)
self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
for ii in range(self.nwalkers)]
self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
except Exception:
logger.warning('Failed to resume. Corrupt checkpoint file {}.'
.format(file_name))
self._set_pos0()
else:
logger.warning('Failed to resume. {} not found.'.format(file_name))
self._set_pos0()
def lnpostfn(self, theta):
log_prior = self.log_prior(theta)
if np.isinf(log_prior):
......
from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from ..utils import logger, get_progress_bar
from . import Emcee
from .base_sampler import SamplerError
......@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee
"""
default_kwargs = dict(ntemps=2, nwalkers=500, Tmax=None, betas=None,
threads=1, pool=None, a=2.0, loglargs=[], logpargs=[],
loglkwargs={}, logpkwargs={}, adaptation_lag=10000,
adaptation_time=100, random=None, iterations=100,
thin=1, storechain=True, adapt=True, swap_ratios=False,
)
default_kwargs = dict(
ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
thin=1, storechain=True, adapt=True, swap_ratios=False)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs}
@property
def checkpoint_info(self):
out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
chain_file = os.path.join(out_dir, 'chain.dat')
last_pos_file = os.path.join(out_dir, 'last_pos.npy')
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
CheckpointInfo = namedtuple(
'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
checkpoint_info = CheckpointInfo(
last_pos_file=last_pos_file, chain_file=chain_file, template=template)
return checkpoint_info
def ntemps(self):
return self.kwargs['ntemps']
def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
def _set_pos0_for_resume(self):
    """Clear ``pos0`` when resuming.

    NOTE(review): presumably the resumed sampler continues from its own
    stored state, so no explicit start position is passed -- confirm.
    """
    resume_position = None
    self.pos0 = resume_position
@property
def stored_chain(self):
return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
@property
def stored_samples(self):
return self.stored_chain[self.search_parameter_keys]
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return self.sampler.time
@property
def stored_loglike(self):
return self.stored_chain['log_l']
def sampler_chain(self):
nsteps = self._previous_iterations
return self.sampler.chain[:, :, :nsteps, :]
@property
def stored_logprior(self):
return self.stored_chain['log_p']
def load_old_chain(self):
try:
last_pos = np.load(self.checkpoint_info.last_pos_file)
self.pos0 = last_pos
self._old_chain = self.stored_samples
logger.info(
'Resuming from {} with {} iterations'.format(
self.checkpoint_info.chain_file,
self._previous_iterations))
except Exception:
logger.info('Unable to resume')
self._set_pos0()
def _pos0_shape(self):
return (self.ntemps, self.nwalkers, self.ndim)
def run_sampler(self):
def _initialise_sampler(self):
import ptemcee
tqdm = get_progress_bar()
sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
logp=self.log_prior, **self.sampler_init_kwargs)
if self.resume:
self.load_old_chain()
else:
self._set_pos0()
self._sampler = ptemcee.Sampler(
dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
**self.sampler_init_kwargs)
def print_tswap_acceptance_fraction(self):
    """Log the sampler's ``tswap_acceptance_fraction`` at info level."""
    fraction = self.sampler.tswap_acceptance_fraction
    logger.info(
        "Sampler per-chain tswap acceptance fraction = {}".format(fraction))
def write_chains_to_file(self, pos, loglike, logpost):
    """Append one step of the first-index chain to the chain file.

    Only index 0 along the leading axis of ``pos``, ``loglike`` and
    ``logpost`` is written (presumably the lowest temperature -- TODO
    confirm). The log-prior column is computed as ``logpost - loglike``.

    Parameters
    ----------
    pos:
        Walker positions; indexed as ``pos[0, :, :]``.
    loglike, logpost:
        Log-likelihood and log-posterior values; indexed as ``[0, :]``.
    """
    with open(self.checkpoint_info.chain_file, "a") as chain_handle:
        cold_loglike = np.squeeze(loglike[0, :])
        cold_logprior = np.squeeze(logpost[0, :]) - cold_loglike
        walkers = zip(pos[0, :, :], cold_loglike, cold_logprior)
        for index, (point, logl, logp) in enumerate(walkers):
            row = np.concatenate((point, [logl, logp]))
            chain_handle.write(
                self.checkpoint_info.chain_template.format(index, *row))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
# main iteration loop
for pos, logpost, loglike in tqdm(
sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs),
self.sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs),
total=iterations):
np.save(self.checkpoint_info.last_pos_file, pos)
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[:1, :])
logprior = np.squeeze(logpost[:1, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.template.format(ii, *line))
self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
self.write_chains_to_file(pos, loglike, logpost)
self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
self.result.sampler_output = np.nan
self.print_nburn_logging_info()
self.print_tswap_acceptance_fraction()
self.result.nburn = self.nburn
if self.result.nburn > self.nsteps:
raise SamplerError(
"The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.")
walkers = self.stored_samples.view((float, self.ndim))
walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim)
self.result.walkers = walkers
self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim))
self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.walkers = self.sampler.chain[0, :, :, :]
n_samples = self.nwalkers * self.nburn
self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
self.result.betas = sampler.betas
self.result.betas = self.sampler.betas
self.result.log_evidence, self.result.log_evidence_err =\
sampler.log_evidence_estimate(
sampler.loglikelihood, self.nburn / self.nsteps)
self.sampler.log_evidence_estimate(
self.sampler.loglikelihood, self.nburn / self.nsteps)
return self.result
......@@ -6,3 +6,4 @@ matplotlib>=2.0
scipy>=0.16
pandas
mock
dill
......@@ -79,6 +79,7 @@ setup(name='bilby',
'future',
'dynesty',
'corner',
'dill',
'numpy>=1.9',
'matplotlib>=2.0',
'pandas',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment