Commit 40c349e6 authored by Gregory Ashton

Merge branch 'emcee_resume' into 'master'

basic checkpointing and resuming

Closes #282

See merge request lscsoft/bilby!346
parents 747ebded be3a886a
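
As a usage note (not part of the commit): a minimal sketch of exercising the new checkpointing from bilby's top-level API. The model, data, priors, and label below are invented for illustration; `resume=True` is the keyword this merge adds to the `Emcee` sampler, with sampler keywords forwarded through `run_sampler`.

import numpy as np
import bilby

# Toy one-parameter inference problem (all values illustrative)
def model(x, mu):
    return mu * np.ones_like(x)

x = np.linspace(0, 1, 100)
y = model(x, mu=1.0) + np.random.normal(0, 0.1, len(x))

likelihood = bilby.core.likelihood.GaussianLikelihood(x, y, model, sigma=0.1)
priors = dict(mu=bilby.core.prior.Uniform(-10, 10, 'mu'))

# With resume=True (the new default), emcee writes every step to
# outdir/emcee_<label>/chain.dat; rerunning the same script continues
# from the last recorded walker positions instead of starting afresh.
result = bilby.run_sampler(
    likelihood=likelihood, priors=priors, sampler='emcee',
    nwalkers=100, iterations=1000, resume=True,
    outdir='outdir', label='checkpoint_demo')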
@@ -3,6 +3,7 @@
 ## Unreleased
 ### Added
+- `emcee` now writes all progress to disk and can resume from a previous run.
 -
 ### Changed
......
 from __future__ import absolute_import, print_function
 import os
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
-from ..utils import logger, get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from .base_sampler import MCMCSampler, SamplerError
@@ -41,19 +44,23 @@ class Emcee(MCMCSampler):
     default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={},
                           postargs=None, pool=None, live_dangerously=False,
                           runtime_sortingfn=None, lnprob0=None, rstate0=None,
-                          blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None)
+                          blobs0=None, iterations=100, thin=1, storechain=True,
+                          mh_proposal=None)

-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25,
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
-        MCMCSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                             use_ratio=use_ratio, plot=plot,
-                             skip_import_verification=skip_import_verification,
-                             **kwargs)
+        MCMCSampler.__init__(
+            self, likelihood=likelihood, priors=priors, outdir=outdir,
+            label=label, use_ratio=use_ratio, plot=plot,
+            skip_import_verification=skip_import_verification, **kwargs)
+        self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
+        self._old_chain = None

     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
@@ -168,23 +175,54 @@
         import emcee
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        self._set_pos0()
-        for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                      total=self.nsteps):
-            pass
+        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+        out_file = os.path.join(out_dir, 'chain.dat')
+        if self.resume:
+            self.load_old_chain(out_file)
+        else:
+            self._set_pos0()
+        check_directory_exists_and_if_not_mkdir(out_dir)
+        if not os.path.isfile(out_file):
+            with open(out_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l\tlog_p\n'.format(
+                    '\t'.join(self.search_parameter_keys)))
+        # One row per walker per step: walker index, parameters, log_l, log_p
+        template = (
+            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n')
+        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
+                           total=self.nsteps):
+            points = np.hstack([sample[0], np.array(sample[3])])
+            with open(out_file, "a") as ff:
+                for ii, point in enumerate(points):
+                    ff.write(template.format(ii, *point))

         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
+        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs_flat.T
+        if self._old_chain is not None:
+            # Prepend the samples recovered from the previous run
+            chain = np.vstack([self._old_chain[:, :-2],
+                               sampler.chain.reshape((-1, self.ndim))])
+            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
+            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
+            self.nsteps = chain.shape[0] // self.nwalkers
+        else:
+            chain = sampler.chain.reshape((-1, self.ndim))
+            log_ls = log_likelihoods
+            log_ps = log_priors
+        self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
+        n_samples = self.nwalkers * self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn > nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[:, self.nburn:, :].reshape((-1, self.ndim))
-        blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape((-1, 2))
-        log_likelihoods, log_priors = blobs_flat.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
+        self.result.samples = chain[n_samples:, :]
+        self.result.log_likelihood_evaluations = log_ls[n_samples:]
+        self.result.log_prior_evaluations = log_ps[n_samples:]
         self.result.walkers = sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
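
For reference, the checkpoint written by the loop above is a plain tab-separated table: a one-line header, then one row per walker per step carrying the walker index, the parameter values, and the log-likelihood and log-prior returned as blobs. A self-contained sketch of that layout, with invented parameter names, a hypothetical file name, and random values:

import numpy as np

search_parameter_keys = ['mass_1', 'mass_2']  # illustrative names
ndim = len(search_parameter_keys)
nwalkers = 3

header = 'walker\t{}\tlog_l\tlog_p\n'.format('\t'.join(search_parameter_keys))
template = '{:d}' + '\t{:.9e}' * (ndim + 2) + '\n'

# One emcee step: positions (nwalkers x ndim) stacked with the per-walker
# blobs (log_likelihood, log_prior), mirroring run_sampler above
pos = np.random.uniform(size=(nwalkers, ndim))
blobs = np.random.uniform(size=(nwalkers, 2))
points = np.hstack([pos, blobs])

with open('chain_demo.dat', 'w') as ff:  # hypothetical file name
    ff.write(header)
    for ii, point in enumerate(points):
        ff.write(template.format(ii, *point))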
@@ -209,6 +247,20 @@
             self.pos0 = [self.get_random_draw_from_prior()
                          for _ in range(self.nwalkers)]

+    def load_old_chain(self, file_name=None):
+        if file_name is None:
+            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+            file_name = os.path.join(out_dir, 'chain.dat')
+        if os.path.isfile(file_name):
+            old_chain = np.genfromtxt(file_name, skip_header=1)
+            # Restart each walker from its last recorded position; strip the
+            # walker-index column and the trailing log_l, log_p columns
+            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
+                         for ii in range(self.nwalkers)]
+            self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
+            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+        else:
+            logger.warning('Failed to resume. {} not found.'.format(file_name))
+            self._set_pos0()
+
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
......
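
To make the slicing in `load_old_chain` concrete: each stored row is `[walker, *params, log_l, log_p]`, so the final `nwalkers` rows hold every walker's most recent position, and columns `1:-2` strip the walker index and the two log columns to recover a parameter vector for `pos0`. A toy demonstration with invented values:

import numpy as np

nwalkers, ndim, nsteps = 4, 2, 5

# Fake stored chain: one row per walker per step, columns are
# [walker_id, param_1, ..., param_ndim, log_l, log_p]
rows = [[walker] + list(np.random.uniform(size=ndim)) + [-1.0, 0.0]
        for step in range(nsteps) for walker in range(nwalkers)]
old_chain = np.array(rows)

# The last nwalkers rows give one position per walker, in walker order;
# dropping column 0 (walker id) and the last two (logs) leaves the params
pos0 = [np.squeeze(old_chain[-(nwalkers - ii), 1:-2])
        for ii in range(nwalkers)]
assert len(pos0) == nwalkers and pos0[0].shape == (ndim,)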