Skip to content
Snippets Groups Projects
Commit 9bffa289 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Overhaul to using pickles and signals

parent 57a23737
No related branches found
No related tags found
1 merge request!423Improvements to checkpointing for emcee/ptemcee
from __future__ import absolute_import, print_function from __future__ import absolute_import, print_function
from collections import namedtuple
import os import os
import signal
import sys
import numpy as np import numpy as np
from pandas import DataFrame from pandas import DataFrame
from distutils.version import LooseVersion from distutils.version import LooseVersion
import dill as pickle
from ..utils import ( from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir) logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
...@@ -66,6 +70,9 @@ class Emcee(MCMCSampler): ...@@ -66,6 +70,9 @@ class Emcee(MCMCSampler):
self.burn_in_fraction = burn_in_fraction self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act self.burn_in_act = burn_in_act
signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
signal.signal(signal.SIGINT, self.checkpoint_and_exit)
def _translate_kwargs(self, kwargs): def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs: if 'nwalkers' not in kwargs:
for equiv in self.nwalkers_equiv_kwargs: for equiv in self.nwalkers_equiv_kwargs:
...@@ -165,66 +172,139 @@ class Emcee(MCMCSampler): ...@@ -165,66 +172,139 @@ class Emcee(MCMCSampler):
def nsteps(self, nsteps): def nsteps(self, nsteps):
self.kwargs['iterations'] = nsteps self.kwargs['iterations'] = nsteps
def __getstate__(self): @property
# In order to be picklable with dill, we need to discard the pool def stored_chain(self):
# object before trying. """ Read the stored zero-temperature chain data in from disk """
d = self.__dict__ return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
d["_Sampler__kwargs"]["pool"] = None
return d
def set_up_checkpoint(self): @property
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) def stored_samples(self):
out_file = os.path.join(out_dir, 'chain.dat') """ Returns the samples stored on disk """
return self.stored_chain[self.search_parameter_keys]
if self.resume: @property
self.load_old_chain(out_file) def stored_loglike(self):
else: """ Returns the log-likelihood stored on disk """
self._set_pos0() return self.stored_chain['log_l']
@property
def stored_logprior(self):
""" Returns the log-prior stored on disk """
return self.stored_chain['log_p']
@property
def checkpoint_info(self):
""" Defines various things related to checkpointing and storing data
Returns
-------
checkpoint_info: named_tuple
An object with attributes `sampler_file`, `chain_file`, and
`chain_template`. The first two give paths to where the sampler and
chain data is stored, the last a formatted-str-template with which
to write the chain data to disk
"""
out_dir = os.path.join(
self.outdir, '{}_{}'.format(self.__class__.__name__, self.label))
check_directory_exists_and_if_not_mkdir(out_dir) check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(out_file):
with open(out_file, "w") as ff: sampler_file = os.path.join(out_dir, 'sampler.pickle')
ff.write('walker\t{}\tlog_l\n'.format(
# Initialise chain file
chain_file = os.path.join(out_dir, 'chain.dat')
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys))) '\t'.join(self.search_parameter_keys)))
template =\ chain_template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
return out_file, template CheckpointInfo = namedtuple(
'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template'])
def run_sampler(self): checkpoint_info = CheckpointInfo(
sampler_file=sampler_file, chain_file=chain_file,
chain_template=chain_template)
return checkpoint_info
@property
def sampler_chain(self):
nsteps = self._previous_iterations
return self.sampler.chain[:, :nsteps, :]
def checkpoint(self):
""" Writes a pickle file of the sampler to disk using dill """
logger.info("Checkpointing sampler to file {}"
.format(self.checkpoint_info.sampler_file))
with open(self.checkpoint_info.sampler_file, 'wb') as f:
# Overwrites the stored sampler chain with one that is truncated
# to only the completed steps
self.sampler._chain = self.sampler_chain
pickle.dump(self._sampler, f)
def checkpoint_and_exit(self, signum, frame):
logger.info("Recieved signal {}".format(signum))
self.checkpoint()
sys.exit()
def _initialise_sampler(self):
import emcee import emcee
tqdm = get_progress_bar() self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
out_file, template = self.set_up_checkpoint() def _set_pos0_for_resume(self):
self.pos0 = self.sampler.chain[:, -1, :]
@property
def sampler(self):
""" Returns the ptemcee sampler object
If, alrady initialized, returns the stored _sampler value. Otherwise,
first checks if there is a pickle file from which to load. If there is
not, then initialize the sampler and set the initial random draw
"""
if hasattr(self, '_sampler'):
pass
elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
with open(self.checkpoint_info.sampler_file, 'rb') as f:
self._sampler = pickle.load(f)
self._set_pos0_for_resume()
else:
self._initialise_sampler()
self._set_pos0()
return self._sampler
def write_chains_to_file(self, sample):
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(self.checkpoint_info.chain_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(self.checkpoint_info.chain_template.format(ii, *point))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations') iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations iterations -= self._previous_iterations
print('pos0', self.pos0)
sampler_function_kwargs['p0'] = self.pos0
for sample in tqdm( for sample in tqdm(
sampler.sample(iterations=iterations, **sampler_function_kwargs), self.sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations): total=iterations):
if self.prerelease: self.write_chains_to_file(sample)
points = np.hstack([sample.coords, sample.blobs])
else:
points = np.hstack([sample[0], np.array(sample[3])])
with open(out_file, "a") as ff:
for ii, point in enumerate(points):
ff.write(template.format(ii, *point))
self.result.sampler_output = np.nan self.result.sampler_output = np.nan
blobs_flat = np.array(sampler.blobs).reshape((-1, 2)) blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2))
log_likelihoods, log_priors = blobs_flat.T log_likelihoods, log_priors = blobs_flat.T
if self._old_chain is not None: chain = self.sampler.chain.reshape((-1, self.ndim))
chain = np.vstack([self._old_chain[:, :-2], log_ls = log_likelihoods
sampler.chain.reshape((-1, self.ndim))]) log_ps = log_priors
log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
log_ps = np.hstack([self._old_chain[:, -1], log_priors])
self.nsteps = chain.shape[0] // self.nwalkers
else:
chain = sampler.chain.reshape((-1, self.ndim))
log_ls = log_likelihoods
log_ps = log_priors
self.calculate_autocorrelation(chain) self.calculate_autocorrelation(chain)
self.print_nburn_logging_info() self.print_nburn_logging_info()
self.result.nburn = self.nburn self.result.nburn = self.nburn
...@@ -236,13 +316,27 @@ class Emcee(MCMCSampler): ...@@ -236,13 +316,27 @@ class Emcee(MCMCSampler):
self.result.samples = chain[n_samples:, :] self.result.samples = chain[n_samples:, :]
self.result.log_likelihood_evaluations = log_ls[n_samples:] self.result.log_likelihood_evaluations = log_ls[n_samples:]
self.result.log_prior_evaluations = log_ps[n_samples:] self.result.log_prior_evaluations = log_ps[n_samples:]
self.result.walkers = sampler.chain self.result.walkers = self.sampler.chain
self.result.log_evidence = np.nan self.result.log_evidence = np.nan
self.result.log_evidence_err = np.nan self.result.log_evidence_err = np.nan
return self.result return self.result
@property
def _previous_iterations(self):
""" Returns the number of iterations that the sampler has saved
This is used when loading in a sampler from a pickle file to figure out
how much of the run has already been completed
"""
return len(self.sampler.blobs)
def _draw_pos0_from_prior(self): def _draw_pos0_from_prior(self):
return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] return np.array(
[self.get_random_draw_from_prior() for _ in range(self.nwalkers)])
@property
def _pos0_shape(self):
return (self.nwalkers, self.ndim)
def _set_pos0(self): def _set_pos0(self):
if self.pos0 is not None: if self.pos0 is not None:
...@@ -250,9 +344,9 @@ class Emcee(MCMCSampler): ...@@ -250,9 +344,9 @@ class Emcee(MCMCSampler):
if isinstance(self.pos0, DataFrame): if isinstance(self.pos0, DataFrame):
self.pos0 = self.pos0[self.search_parameter_keys].values self.pos0 = self.pos0[self.search_parameter_keys].values
elif type(self.pos0) in (list, np.ndarray): elif type(self.pos0) in (list, np.ndarray):
self.pos0 = np.squeeze(self.kwargs['pos0']) self.pos0 = np.squeeze(self.pos0)
if self.pos0.shape != (self.nwalkers, self.ndim): if self.pos0.shape != self._pos0_shape:
raise ValueError( raise ValueError(
'Input pos0 should be of shape ndim, nwalkers') 'Input pos0 should be of shape ndim, nwalkers')
logger.debug("Checking input pos0") logger.debug("Checking input pos0")
...@@ -262,51 +356,6 @@ class Emcee(MCMCSampler): ...@@ -262,51 +356,6 @@ class Emcee(MCMCSampler):
logger.debug("Generating initial walker positions from prior") logger.debug("Generating initial walker positions from prior")
self.pos0 = self._draw_pos0_from_prior() self.pos0 = self._draw_pos0_from_prior()
@property
def _old_chain(self):
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx, :]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
@property
def _previous_iterations(self):
if self._old_chain is None:
return 0
try:
return self._old_chain.shape[0] // self.nwalkers
except AttributeError:
logger.warning(
"Unable to calculate previous iterations from checkpoint,"
" defaulting to zero")
return 0
def load_old_chain(self, file_name=None):
if file_name is None:
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
file_name = os.path.join(out_dir, 'chain.dat')
if os.path.isfile(file_name):
try:
old_chain = np.genfromtxt(file_name, skip_header=1)
self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
for ii in range(self.nwalkers)]
self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
except Exception:
logger.warning('Failed to resume. Corrupt checkpoint file {}.'
.format(file_name))
self._set_pos0()
else:
logger.warning('Failed to resume. {} not found.'.format(file_name))
self._set_pos0()
def lnpostfn(self, theta): def lnpostfn(self, theta):
log_prior = self.log_prior(theta) log_prior = self.log_prior(theta)
if np.isinf(log_prior): if np.isinf(log_prior):
......
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np import numpy as np
from ..utils import ( from ..utils import logger, get_progress_bar
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from . import Emcee from . import Emcee
from .base_sampler import SamplerError from .base_sampler import SamplerError
...@@ -31,12 +27,11 @@ class Ptemcee(Emcee): ...@@ -31,12 +27,11 @@ class Ptemcee(Emcee):
The number of temperatures used by ptemcee The number of temperatures used by ptemcee
""" """
default_kwargs = dict(ntemps=2, nwalkers=500, Tmax=None, betas=None, default_kwargs = dict(
threads=1, pool=None, a=2.0, loglargs=[], logpargs=[], ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None,
loglkwargs={}, logpkwargs={}, adaptation_lag=10000, a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={},
adaptation_time=100, random=None, iterations=100, adaptation_lag=10000, adaptation_time=100, random=None, iterations=100,
thin=1, storechain=True, adapt=True, swap_ratios=False, thin=1, storechain=True, adapt=True, swap_ratios=False)
)
def __init__(self, likelihood, priors, outdir='outdir', label='label', def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False, use_ratio=False, plot=False, skip_import_verification=False,
...@@ -61,120 +56,88 @@ class Ptemcee(Emcee): ...@@ -61,120 +56,88 @@ class Ptemcee(Emcee):
if key not in self.sampler_function_kwargs} if key not in self.sampler_function_kwargs}
@property @property
def checkpoint_info(self): def ntemps(self):
out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label)) return self.kwargs['ntemps']
chain_file = os.path.join(out_dir, 'chain.dat')
last_pos_file = os.path.join(out_dir, 'last_pos.npy')
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(chain_file):
with open(chain_file, "w") as ff:
ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
CheckpointInfo = namedtuple(
'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])
checkpoint_info = CheckpointInfo(
last_pos_file=last_pos_file, chain_file=chain_file, template=template)
return checkpoint_info
def _draw_pos0_from_prior(self): def _draw_pos0_from_prior(self):
# for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim
return [[self.get_random_draw_from_prior() return [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)] for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])] for _ in range(self.kwargs['ntemps'])]
@property def _set_pos0_for_resume(self):
def _old_chain(self): self.pos0 = None
try:
old_chain = self.__old_chain
n = old_chain.shape[0]
idx = n - np.mod(n, self.nwalkers)
return old_chain[:idx]
except AttributeError:
return None
@_old_chain.setter
def _old_chain(self, old_chain):
self.__old_chain = old_chain
@property @property
def stored_chain(self): def _previous_iterations(self):
return np.genfromtxt(self.checkpoint_info.chain_file, names=True) """ Returns the number of iterations that the sampler has saved
@property This is used when loading in a sampler from a pickle file to figure out
def stored_samples(self): how much of the run has already been completed
return self.stored_chain[self.search_parameter_keys] """
return self.sampler.time
@property @property
def stored_loglike(self): def sampler_chain(self):
return self.stored_chain['log_l'] nsteps = self._previous_iterations
return self.sampler.chain[:, :, :nsteps, :]
@property @property
def stored_logprior(self): def _pos0_shape(self):
return self.stored_chain['log_p'] return (self.ntemps, self.nwalkers, self.ndim)
def load_old_chain(self):
try:
last_pos = np.load(self.checkpoint_info.last_pos_file)
self.pos0 = last_pos
self._old_chain = self.stored_samples
logger.info(
'Resuming from {} with {} iterations'.format(
self.checkpoint_info.chain_file,
self._previous_iterations))
except Exception:
logger.info('Unable to resume')
self._set_pos0()
def run_sampler(self): def _initialise_sampler(self):
import ptemcee import ptemcee
tqdm = get_progress_bar() self._sampler = ptemcee.Sampler(
sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood, dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior,
logp=self.log_prior, **self.sampler_init_kwargs) **self.sampler_init_kwargs)
if self.resume: def print_tswap_acceptance_fraction(self):
self.load_old_chain() logger.info("Sampler per-chain tswap acceptance fraction = {}".format(
else: self.sampler.tswap_acceptance_fraction))
self._set_pos0()
def write_chains_to_file(self, pos, loglike, logpost):
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[0, :])
logprior = np.squeeze(logpost[0, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.chain_template.format(ii, *line))
def run_sampler(self):
tqdm = get_progress_bar()
sampler_function_kwargs = self.sampler_function_kwargs sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations') iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations iterations -= self._previous_iterations
# main iteration loop
for pos, logpost, loglike in tqdm( for pos, logpost, loglike in tqdm(
sampler.sample(self.pos0, iterations=iterations, self.sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs), **sampler_function_kwargs),
total=iterations): total=iterations):
np.save(self.checkpoint_info.last_pos_file, pos) self.write_chains_to_file(pos, loglike, logpost)
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[:1, :]) self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
logprior = np.squeeze(logpost[:1, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.template.format(ii, *line))
self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
self.result.sampler_output = np.nan self.result.sampler_output = np.nan
self.print_nburn_logging_info() self.print_nburn_logging_info()
self.print_tswap_acceptance_fraction()
self.result.nburn = self.nburn self.result.nburn = self.nburn
if self.result.nburn > self.nsteps: if self.result.nburn > self.nsteps:
raise SamplerError( raise SamplerError(
"The run has finished, but the chain is not burned in: " "The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.") "`nburn < nsteps`. Try increasing the number of steps.")
walkers = self.stored_samples.view((float, self.ndim))
walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim) self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape(
self.result.walkers = walkers (-1, self.ndim))
self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim)) self.result.walkers = self.sampler.chain[0, :, :, :]
n_samples = self.nwalkers * self.nburn n_samples = self.nwalkers * self.nburn
self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:] self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
self.result.log_prior_evaluations = self.stored_logprior[n_samples:] self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
self.result.betas = sampler.betas self.result.betas = self.sampler.betas
self.result.log_evidence, self.result.log_evidence_err =\ self.result.log_evidence, self.result.log_evidence_err =\
sampler.log_evidence_estimate( self.sampler.log_evidence_estimate(
sampler.loglikelihood, self.nburn / self.nsteps) self.sampler.loglikelihood, self.nburn / self.nsteps)
return self.result return self.result
...@@ -6,3 +6,4 @@ matplotlib>=2.0 ...@@ -6,3 +6,4 @@ matplotlib>=2.0
scipy>=0.16 scipy>=0.16
pandas pandas
mock mock
dill
...@@ -79,6 +79,7 @@ setup(name='bilby', ...@@ -79,6 +79,7 @@ setup(name='bilby',
'future', 'future',
'dynesty', 'dynesty',
'corner', 'corner',
'dill',
'numpy>=1.9', 'numpy>=1.9',
'matplotlib>=2.0', 'matplotlib>=2.0',
'pandas', 'pandas',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment