Commit be3a886a authored by Colm Talbot, committed by Gregory Ashton

basic checkpointing and resuming

parent 747ebded
@@ -3,6 +3,7 @@
 ## Unreleased
 ### Added
+- `emcee` now writes all progress to disk and can resume from a previous run.
 -
 ### Changed
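The changelog entry is the user-visible summary: the `emcee` wrapper now appends every step to `<outdir>/emcee_<label>/chain.dat` and, because `resume` defaults to `True`, re-running the same job picks up from that file. A minimal usage sketch, assuming bilby's standard `run_sampler` API; the toy likelihood, prior, and settings below are illustrative, not from this commit:

```python
import numpy as np
import bilby

# Toy inference problem: estimate the mean of Gaussian-distributed data.
x = np.linspace(0, 1, 100)
y = np.random.normal(0.5, 1.0, 100)
likelihood = bilby.core.likelihood.GaussianLikelihood(
    x, y, func=lambda x, mu: mu * np.ones_like(x), sigma=1.0)
priors = dict(mu=bilby.core.prior.Uniform(-5, 5, 'mu'))

# The first call writes outdir/emcee_label/chain.dat as it runs; if the
# job is interrupted, repeating the identical call resumes from that file.
result = bilby.run_sampler(
    likelihood=likelihood, priors=priors, sampler='emcee',
    nwalkers=100, iterations=500, resume=True,
    outdir='outdir', label='label')
```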
@@ ... @@
 from __future__ import absolute_import, print_function
 import os
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
-from ..utils import logger, get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from .base_sampler import MCMCSampler, SamplerError
@@ -41,19 +44,23 @@ class Emcee(MCMCSampler):
     default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={},
                           postargs=None, pool=None, live_dangerously=False,
                           runtime_sortingfn=None, lnprob0=None, rstate0=None,
-                          blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None)
+                          blobs0=None, iterations=100, thin=1, storechain=True,
+                          mh_proposal=None)

-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25,
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
-        MCMCSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                             use_ratio=use_ratio, plot=plot,
-                             skip_import_verification=skip_import_verification,
-                             **kwargs)
+        MCMCSampler.__init__(
+            self, likelihood=likelihood, priors=priors, outdir=outdir,
+            label=label, use_ratio=use_ratio, plot=plot,
+            skip_import_verification=skip_import_verification, **kwargs)
+        self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
+        self._old_chain = None

     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
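Two behavioural details in the constructor change are easy to miss: `resume` defaults to `True`, so existing scripts gain checkpointing and resumption without modification, and `self._old_chain = None` serves as a sentinel that `run_sampler` (next hunk) checks to decide whether a previously saved chain should be stitched onto the new samples.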
@@ -168,23 +175,54 @@ class Emcee(MCMCSampler):
         import emcee
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        self._set_pos0()
-        for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                      total=self.nsteps):
-            pass
+        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+        out_file = os.path.join(out_dir, 'chain.dat')
+        if self.resume:
+            self.load_old_chain(out_file)
+        else:
+            self._set_pos0()
+        check_directory_exists_and_if_not_mkdir(out_dir)
+        if not os.path.isfile(out_file):
+            with open(out_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
+                    '\t'.join(self.search_parameter_keys)))
+        template = \
+            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
+                           total=self.nsteps):
+            points = np.hstack([sample[0], np.array(sample[3])])
+            with open(out_file, "a") as ff:
+                for ii, point in enumerate(points):
+                    ff.write(template.format(ii, *point))
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
+        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs_flat.T
+        if self._old_chain is not None:
+            chain = np.vstack([self._old_chain[:, :-2],
+                               sampler.chain.reshape((-1, self.ndim))])
+            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
+            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
+            self.nsteps = chain.shape[0] // self.nwalkers
+        else:
+            chain = sampler.chain.reshape((-1, self.ndim))
+            log_ls = log_likelihoods
+            log_ps = log_priors
+        self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
+        n_samples = self.nwalkers * self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn > nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[:, self.nburn:, :].reshape((-1, self.ndim))
-        blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape((-1, 2))
-        log_likelihoods, log_priors = blobs_flat.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
+        self.result.samples = chain[n_samples:, :]
+        self.result.log_likelihood_evaluations = log_ls[n_samples:]
+        self.result.log_prior_evaluations = log_ps[n_samples:]
         self.result.walkers = sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
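For reference, the checkpoint written above is a plain tab-separated file: one header row, then one row per walker per step containing the walker index, the parameter values, and the log-likelihood and log-prior blobs. A sketch of reading it back by hand, mirroring what `load_old_chain` does in the next hunk; the path and `nwalkers` value are illustrative:

```python
import numpy as np

# Load a checkpoint file written by the sampling loop above.
old_chain = np.genfromtxt('outdir/emcee_label/chain.dat', skip_header=1)

nwalkers = 100  # must match the run that produced the file

# The last nwalkers rows are the most recent position of each walker;
# dropping the walker-index column (first) and the two log columns
# (last) recovers a valid starting position for a resumed run.
pos0 = old_chain[-nwalkers:, 1:-2]

# Everything else is flattened history: parameter columns, then
# log-likelihood, then log-prior.
samples = old_chain[:, 1:-2]
log_likelihoods = old_chain[:, -2]
```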
@@ -209,6 +247,20 @@ class Emcee(MCMCSampler):
             self.pos0 = [self.get_random_draw_from_prior()
                          for _ in range(self.nwalkers)]

+    def load_old_chain(self, file_name=None):
+        if file_name is None:
+            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+            file_name = os.path.join(out_dir, 'chain.dat')
+        if os.path.isfile(file_name):
+            old_chain = np.genfromtxt(file_name, skip_header=1)
+            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
+                         for ii in range(self.nwalkers)]
+            self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
+            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+        else:
+            logger.warning('Failed to resume. {} not found.'.format(file_name))
+            self._set_pos0()
+
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
...
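The `lnpostfn` truncated at the end of this hunk is what feeds the blobs consumed in `run_sampler`: it returns the log-posterior together with the pair `(log_likelihood, log_prior)`, which emcee accumulates per walker per step in `sampler.blobs`. A self-contained sketch of the pattern with toy stand-in likelihood and prior functions, not the project's exact code:

```python
import numpy as np

def log_likelihood_fn(theta):
    # Toy stand-in: unit Gaussian log-likelihood.
    return -0.5 * float(np.sum(np.asarray(theta) ** 2))

def log_prior_fn(theta):
    # Toy stand-in: flat prior on [-5, 5] in every dimension.
    return 0.0 if np.all(np.abs(np.asarray(theta)) < 5) else -np.inf

def lnpostfn(theta):
    """Return the log-posterior plus a (log_l, log_p) blob for emcee."""
    log_prior = log_prior_fn(theta)
    if np.isinf(log_prior):
        # Skip the likelihood call outside the prior support.
        return -np.inf, [np.nan, np.nan]
    log_l = log_likelihood_fn(theta)
    return log_l + log_prior, [log_l, log_prior]
```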