Commit be3a886a authored by Colm Talbot, committed by Gregory Ashton

basic checkpointing and resuming

parent 747ebded
@@ -3,6 +3,7 @@
 ## Unreleased
 ### Added
+- `emcee` now writes all progress to disk and can resume from a previous run.
 -
 ### Changed
...
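The rest of the commit modifies the `emcee` sampler class shown in the diff below. As a rough usage sketch (not part of the commit; the likelihood, priors, import path, and keyword values are placeholders), resuming is controlled by the new `resume` keyword added to `__init__`, which defaults to `True`:

# Hypothetical usage sketch; `likelihood` and `priors` stand in for objects
# built elsewhere, and the import path is assumed rather than taken from
# this diff.
from tupak.core.sampler.emcee import Emcee  # assumed location of the class

sampler = Emcee(likelihood=likelihood, priors=priors,
                outdir='outdir', label='label',
                nwalkers=100, iterations=500,
                resume=True)   # resume=False forces a fresh start via _set_pos0()
result = sampler.run_sampler()  # assumes the package's usual run_sampler entry point

With `resume=True`, the sampler looks for `outdir/emcee_label/chain.dat` and, if it exists, restarts the walkers from the last recorded positions instead of drawing fresh points from the prior.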
 from __future__ import absolute_import, print_function
+import os
 import numpy as np
 from pandas import DataFrame
 from distutils.version import LooseVersion
-from ..utils import logger, get_progress_bar
+from ..utils import (
+    logger, get_progress_bar, check_directory_exists_and_if_not_mkdir)
 from .base_sampler import MCMCSampler, SamplerError
@@ -41,19 +44,23 @@ class Emcee(MCMCSampler):
     default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={},
                           postargs=None, pool=None, live_dangerously=False,
                           runtime_sortingfn=None, lnprob0=None, rstate0=None,
-                          blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None)
+                          blobs0=None, iterations=100, thin=1, storechain=True,
+                          mh_proposal=None)
 
-    def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False,
-                 skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25,
+    def __init__(self, likelihood, priors, outdir='outdir', label='label',
+                 use_ratio=False, plot=False, skip_import_verification=False,
+                 pos0=None, nburn=None, burn_in_fraction=0.25, resume=True,
                  burn_in_act=3, **kwargs):
-        MCMCSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label,
-                             use_ratio=use_ratio, plot=plot,
-                             skip_import_verification=skip_import_verification,
-                             **kwargs)
+        MCMCSampler.__init__(
+            self, likelihood=likelihood, priors=priors, outdir=outdir,
+            label=label, use_ratio=use_ratio, plot=plot,
+            skip_import_verification=skip_import_verification, **kwargs)
+        self.resume = resume
         self.pos0 = pos0
         self.nburn = nburn
         self.burn_in_fraction = burn_in_fraction
         self.burn_in_act = burn_in_act
+        self._old_chain = None
 
     def _translate_kwargs(self, kwargs):
         if 'nwalkers' not in kwargs:
@@ -168,23 +175,54 @@ class Emcee(MCMCSampler):
         import emcee
         tqdm = get_progress_bar()
         sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs)
-        self._set_pos0()
-        for _ in tqdm(sampler.sample(**self.sampler_function_kwargs),
-                      total=self.nsteps):
-            pass
+        out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+        out_file = os.path.join(out_dir, 'chain.dat')
+
+        if self.resume:
+            self.load_old_chain(out_file)
+        else:
+            self._set_pos0()
+
+        check_directory_exists_and_if_not_mkdir(out_dir)
+        if not os.path.isfile(out_file):
+            with open(out_file, "w") as ff:
+                ff.write('walker\t{}\tlog_l'.format(
+                    '\t'.join(self.search_parameter_keys)))
+        template =\
+            '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n'
+
+        for sample in tqdm(sampler.sample(**self.sampler_function_kwargs),
+                           total=self.nsteps):
+            points = np.hstack([sample[0], np.array(sample[3])])
+            # import IPython; IPython.embed()
+            with open(out_file, "a") as ff:
+                for ii, point in enumerate(points):
+                    ff.write(template.format(ii, *point))
+
         self.result.sampler_output = np.nan
-        self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim)))
+        blobs_flat = np.array(sampler.blobs).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs_flat.T
+        if self._old_chain is not None:
+            chain = np.vstack([self._old_chain[:, :-2],
+                               sampler.chain.reshape((-1, self.ndim))])
+            log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods])
+            log_ps = np.hstack([self._old_chain[:, -1], log_priors])
+            self.nsteps = chain.shape[0] // self.nwalkers
+        else:
+            chain = sampler.chain.reshape((-1, self.ndim))
+            log_ls = log_likelihoods
+            log_ps = log_priors
+        self.calculate_autocorrelation(chain)
         self.print_nburn_logging_info()
         self.result.nburn = self.nburn
+        n_samples = self.nwalkers * self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
                 "The run has finished, but the chain is not burned in: "
                 "`nburn < nsteps`. Try increasing the number of steps.")
-        self.result.samples = sampler.chain[:, self.nburn:, :].reshape((-1, self.ndim))
-        blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape((-1, 2))
-        log_likelihoods, log_priors = blobs_flat.T
-        self.result.log_likelihood_evaluations = log_likelihoods
-        self.result.log_prior_evaluations = log_priors
+        self.result.samples = chain[n_samples:, :]
+        self.result.log_likelihood_evaluations = log_ls[n_samples:]
+        self.result.log_prior_evaluations = log_ps[n_samples:]
         self.result.walkers = sampler.chain
         self.result.log_evidence = np.nan
         self.result.log_evidence_err = np.nan
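For reference, the checkpoint written by the loop above is plain tab-separated text: one header line naming the walker column, the search parameters, and the log-likelihood, followed by one row per walker per iteration containing the walker index, the parameter values, and the two blob values (log-likelihood and log-prior) returned by lnpostfn. A minimal inspection sketch, assuming the default outdir and label and that a chain file already exists (nothing here is part of the commit):

# Illustrative only: load the checkpoint and split it into its columns,
# mirroring the indexing used by load_old_chain further down.
import numpy as np

data = np.genfromtxt('outdir/emcee_label/chain.dat', skip_header=1)
walker_index = data[:, 0].astype(int)
parameters = data[:, 1:-2]      # shape (n_rows, n_search_parameters)
log_likelihood = data[:, -2]
log_prior = data[:, -1]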
@@ -209,6 +247,20 @@ class Emcee(MCMCSampler):
             self.pos0 = [self.get_random_draw_from_prior()
                          for _ in range(self.nwalkers)]
 
+    def load_old_chain(self, file_name=None):
+        if file_name is None:
+            out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label))
+            file_name = os.path.join(out_dir, 'chain.dat')
+        if os.path.isfile(file_name):
+            old_chain = np.genfromtxt(file_name, skip_header=1)
+            self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2])
+                         for ii in range(self.nwalkers)]
+            self._old_chain = old_chain[:-self.nwalkers + 1, 1:]
+            logger.info('Resuming from {}'.format(os.path.abspath(file_name)))
+        else:
+            logger.warning('Failed to resume. {} not found.'.format(file_name))
+            self._set_pos0()
+
     def lnpostfn(self, theta):
         log_prior = self.log_prior(theta)
         if np.isinf(log_prior):
...
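As a sanity check on the indexing in load_old_chain, the toy script below writes a file with the same column layout (two walkers, one parameter named 'x', three fake iterations) and recovers each walker's last recorded position the same way pos0 is rebuilt above. It is entirely illustrative; none of the names or values come from the commit.

# Toy round-trip of the checkpoint column layout.
import os
import numpy as np

nwalkers, keys = 2, ['x']
out_file = os.path.join('outdir', 'emcee_label', 'chain.dat')
os.makedirs(os.path.dirname(out_file), exist_ok=True)

template = '{:d}' + '\t{:.9e}' * (len(keys) + 2) + '\n'
with open(out_file, 'w') as ff:
    ff.write('walker\t{}\tlog_l\n'.format('\t'.join(keys)))
    for step in range(3):              # fake iterations
        for ii in range(nwalkers):     # rows are written walker by walker
            ff.write(template.format(ii, ii + 0.1 * step, -1.0, 0.0))

old_chain = np.genfromtxt(out_file, skip_header=1)
# The last nwalkers rows hold the latest position of each walker, in order.
pos0 = [np.squeeze(old_chain[-(nwalkers - ii), 1:-2]) for ii in range(nwalkers)]
print(pos0)   # latest positions: walker 0 -> 0.2, walker 1 -> 1.2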