Skip to content
Snippets Groups Projects
Commit 57a23737 authored by Gregory Ashton
Browse files

Initial work on adding checkpointing to ptemcee

parent 5b39ddf1
No related branches found
No related tags found
No related merge requests found
......@@ -65,7 +65,6 @@ class Emcee(MCMCSampler):
self.nburn = nburn
self.burn_in_fraction = burn_in_fraction
self.burn_in_act = burn_in_act
self._old_chain = None
def _translate_kwargs(self, kwargs):
if 'nwalkers' not in kwargs:
......@@ -173,10 +172,7 @@ class Emcee(MCMCSampler):
d["_Sampler__kwargs"]["pool"] = None
return d
def run_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
def set_up_checkpoint(self):
out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
out_file = os.path.join(out_dir, 'chain.dat')
......@@ -188,13 +184,26 @@ class Emcee(MCMCSampler):
check_directory_exists_and_if_not_mkdir(out_dir)
if not os.path.isfile(out_file):
with open(out_file, "w") as ff:
ff.write('walker\t{}\tlog_l'.format(
ff.write('walker\t{}\tlog_l\n'.format(
'\t'.join(self.search_parameter_keys)))
template =\
'{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
total=self.nsteps):
return out_file, template
def run_sampler(self):
import emcee
tqdm = get_progress_bar()
sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
out_file, template = self.set_up_checkpoint()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
for sample in tqdm(
sampler.sample(iterations=iterations, **sampler_function_kwargs),
total=iterations):
if self.prerelease:
points = np.hstack([sample.coords, sample.blobs])
else:
......@@ -232,6 +241,9 @@ class Emcee(MCMCSampler):
self.result.log_evidence_err = np.nan
return self.result
def _draw_pos0_from_prior(self):
    """Draw one starting position per walker from the prior.

    Returns
    -------
    list
        ``self.nwalkers`` entries, each the result of a separate call to
        ``self.get_random_draw_from_prior()``.
    """
    draws = []
    for _ in range(self.nwalkers):
        draws.append(self.get_random_draw_from_prior())
    return draws
def _set_pos0(self):
if self.pos0 is not None:
logger.debug("Using given initial positions for walkers")
......@@ -248,19 +260,49 @@ class Emcee(MCMCSampler):
self.check_draw(draw)
else:
logger.debug("Generating initial walker positions from prior")
self.pos0 = [self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
self.pos0 = self._draw_pos0_from_prior()
@property
def _old_chain(self):
    """Previously stored chain, trimmed to a whole number of ensemble steps.

    Trailing rows that do not fill a complete group of ``self.nwalkers``
    rows are dropped. ``None`` is returned whenever an ``AttributeError``
    is raised on the way, e.g. when no chain has been stored yet.
    """
    try:
        stored = self.__old_chain
        n_rows = stored.shape[0]
        # Number of rows that make up complete groups of nwalkers rows.
        n_keep = n_rows - np.mod(n_rows, self.nwalkers)
        return stored[:n_keep, :]
    except AttributeError:
        return None


@_old_chain.setter
def _old_chain(self, old_chain):
    """Store ``old_chain`` unchanged; trimming is applied when it is read."""
    self.__old_chain = old_chain
@property
def _previous_iterations(self):
    """Number of iterations already present in the stored chain.

    Returns
    -------
    int
        Row count of ``self._old_chain`` floor-divided by ``self.nwalkers``,
        or 0 when there is no stored chain or the count cannot be computed.
    """
    chain = self._old_chain
    if chain is not None:
        try:
            return chain.shape[0] // self.nwalkers
        except AttributeError:
            logger.warning(
                "Unable to calculate previous iterations from checkpoint,"
                " defaulting to zero")
    return 0
def load_old_chain(self, file_name=None):
    """Restore walker positions and the stored chain from a checkpoint file.

    The file is parsed exactly once and always inside the guarded section,
    so an unreadable checkpoint falls back to fresh initial positions
    instead of raising.

    Parameters
    ----------
    file_name: str, optional
        Path of the chain file. When omitted, ``chain.dat`` inside
        ``<self.outdir>/emcee_<self.label>`` is used.

    Notes
    -----
    On success ``self.pos0`` and ``self._old_chain`` are set from the file.
    If the file does not exist or cannot be parsed, a warning is logged and
    ``self._set_pos0()`` is called instead.
    """
    if file_name is None:
        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
        file_name = os.path.join(out_dir, 'chain.dat')

    if not os.path.isfile(file_name):
        logger.warning('Failed to resume. {} not found.'.format(file_name))
        self._set_pos0()
        return

    try:
        old_chain = np.genfromtxt(file_name, skip_header=1)
        # One entry per walker, taken from the last nwalkers rows; the first
        # column and the last two columns are excluded.
        self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
                     for ii in range(self.nwalkers)]
        # Everything except the final nwalkers - 1 rows, without the first
        # column.
        self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
        logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
    except Exception:
        # Deliberately broad: any parsing or indexing failure means the
        # checkpoint is unusable, and the run should start afresh.
        logger.warning('Failed to resume. Corrupt checkpoint file {}.'
                       .format(file_name))
        self._set_pos0()
......
from __future__ import absolute_import, division, print_function
import os
from collections import namedtuple
import numpy as np
from ..utils import get_progress_bar
from ..utils import (
logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
from . import Emcee
from .base_sampler import SamplerError
......@@ -36,13 +40,14 @@ class Ptemcee(Emcee):
def __init__(self, likelihood, priors, outdir='outdir', label='label',
             use_ratio=False, plot=False, skip_import_verification=False,
             nburn=None, burn_in_fraction=0.25, burn_in_act=3, resume=True,
             **kwargs):
    """Set up the sampler by delegating every argument to ``Emcee.__init__``.

    All named arguments, including ``resume``, are forwarded unchanged, so
    passing ``resume=False`` is honoured rather than being overridden.
    Remaining ``**kwargs`` are passed through as well.
    """
    Emcee.__init__(
        self, likelihood=likelihood, priors=priors, outdir=outdir,
        label=label, use_ratio=use_ratio, plot=plot,
        skip_import_verification=skip_import_verification,
        nburn=nburn, burn_in_fraction=burn_in_fraction,
        burn_in_act=burn_in_act, resume=resume, **kwargs)
@property
def sampler_function_kwargs(self):
......@@ -55,23 +60,102 @@ class Ptemcee(Emcee):
for key, value in self.kwargs.items()
if key not in self.sampler_function_kwargs}
@property
def checkpoint_info(self):
    """Locations and row format used for checkpointing.

    Makes sure the ``ptemcee_<label>`` directory under ``self.outdir``
    exists and, if there is no chain file yet, creates one containing only
    the tab-separated header line.

    Returns
    -------
    CheckpointInfo
        Named tuple with ``last_pos_file``, ``chain_file`` and ``template``
        (the format string for one row: an integer followed by one float
        per search parameter plus two more).
    """
    CheckpointInfo = namedtuple(
        'CheckpointInfo', ['last_pos_file', 'chain_file', 'template'])

    directory = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label))
    chain_path = os.path.join(directory, 'chain.dat')
    position_path = os.path.join(directory, 'last_pos.npy')
    check_directory_exists_and_if_not_mkdir(directory)

    # Only a missing file gets a header; an existing chain is left untouched.
    if not os.path.isfile(chain_path):
        with open(chain_path, "w") as ff:
            ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
                '\t'.join(self.search_parameter_keys)))

    n_float_columns = len(self.search_parameter_keys) + 2
    row_template = ''.join(
        ['{:d}'] + ['\t{:.9e}'] * n_float_columns + ['\n'])

    return CheckpointInfo(
        last_pos_file=position_path, chain_file=chain_path,
        template=row_template)
def _draw_pos0_from_prior(self):
    """Draw starting positions for every temperature and walker.

    Returns
    -------
    list of list
        ``self.kwargs['ntemps']`` inner lists, each holding
        ``self.nwalkers`` results of ``self.get_random_draw_from_prior()``.
    """
    positions = []
    for _ in range(self.kwargs['ntemps']):
        per_temperature = []
        for _ in range(self.nwalkers):
            per_temperature.append(self.get_random_draw_from_prior())
        positions.append(per_temperature)
    return positions
@property
def _old_chain(self):
    """Stored chain cut back to a whole number of ``self.nwalkers`` rows.

    Returns ``None`` whenever an ``AttributeError`` is raised while reading
    or trimming the stored value, e.g. when nothing has been stored yet.
    """
    try:
        stored = self.__old_chain
        n_rows = stored.shape[0]
        # Rows left over after the last complete group of nwalkers rows.
        n_partial = np.mod(n_rows, self.nwalkers)
        return stored[:n_rows - n_partial]
    except AttributeError:
        return None


@_old_chain.setter
def _old_chain(self, old_chain):
    """Keep ``old_chain`` as given; it is trimmed only when read back."""
    self.__old_chain = old_chain
@property
def stored_chain(self):
    """Contents of the checkpoint chain file, read afresh on each access.

    The file is parsed with ``names=True``, so its first line supplies the
    field names of the returned array.
    """
    chain_path = self.checkpoint_info.chain_file
    return np.genfromtxt(chain_path, names=True)
@property
def stored_samples(self):
    """Fields of ``self.stored_chain`` named in ``self.search_parameter_keys``."""
    chain = self.stored_chain
    return chain[self.search_parameter_keys]
@property
def stored_loglike(self):
    """The ``'log_l'`` field of ``self.stored_chain``."""
    chain = self.stored_chain
    return chain['log_l']
@property
def stored_logprior(self):
    """The ``'log_p'`` field of ``self.stored_chain``."""
    chain = self.stored_chain
    return chain['log_p']
def load_old_chain(self):
    """Try to resume from the checkpoint files, else start afresh.

    On success ``self.pos0`` holds the array loaded from the last-position
    file and ``self._old_chain`` the stored samples. If anything in that
    sequence raises, 'Unable to resume' is logged and ``self._set_pos0()``
    is called instead.
    """
    try:
        self.pos0 = np.load(self.checkpoint_info.last_pos_file)
        self._old_chain = self.stored_samples
        message = 'Resuming from {} with {} iterations'.format(
            self.checkpoint_info.chain_file, self._previous_iterations)
        logger.info(message)
    except Exception:
        # Best effort by design: any failure means starting without a
        # checkpoint.
        logger.info('Unable to resume')
        self._set_pos0()
def run_sampler(self):
import ptemcee
tqdm = get_progress_bar()
sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood,
logp=self.log_prior, **self.sampler_init_kwargs)
self.pos0 = [[self.get_random_draw_from_prior()
for _ in range(self.nwalkers)]
for _ in range(self.kwargs['ntemps'])]
log_likelihood_evaluations = []
log_prior_evaluations = []
if self.resume:
self.load_old_chain()
else:
self._set_pos0()
sampler_function_kwargs = self.sampler_function_kwargs
iterations = sampler_function_kwargs.pop('iterations')
iterations -= self._previous_iterations
for pos, logpost, loglike in tqdm(
sampler.sample(self.pos0, **self.sampler_function_kwargs),
total=self.nsteps):
log_likelihood_evaluations.append(loglike)
log_prior_evaluations.append(logpost - loglike)
pass
sampler.sample(self.pos0, iterations=iterations,
**sampler_function_kwargs),
total=iterations):
np.save(self.checkpoint_info.last_pos_file, pos)
with open(self.checkpoint_info.chain_file, "a") as ff:
loglike = np.squeeze(loglike[:1, :])
logprior = np.squeeze(logpost[:1, :]) - loglike
for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)):
line = np.concatenate((point, [logl, logp]))
ff.write(self.checkpoint_info.template.format(ii, *line))
self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
self.result.sampler_output = np.nan
......@@ -81,16 +165,16 @@ class Ptemcee(Emcee):
raise SamplerError(
"The run has finished, but the chain is not burned in: "
"`nburn < nsteps`. Try increasing the number of steps.")
self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape(
(-1, self.ndim))
self.result.log_likelihood_evaluations = np.array(
log_likelihood_evaluations)[self.nburn:, 0, :].reshape((-1))
self.result.log_prior_evaluations = np.array(
log_prior_evaluations)[self.nburn:, 0, :].reshape((-1))
walkers = self.stored_samples.view((float, self.ndim))
walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim)
self.result.walkers = walkers
self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim))
n_samples = self.nwalkers * self.nburn
self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:]
self.result.log_prior_evaluations = self.stored_logprior[n_samples:]
self.result.betas = sampler.betas
self.result.log_evidence, self.result.log_evidence_err =\
sampler.log_evidence_estimate(
sampler.loglikelihood, self.nburn / self.nsteps)
self.result.walkers = sampler.chain[0, :, :, :]
return self.result
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment