From be3a886a1666e60d288811b604ab549997d858e0 Mon Sep 17 00:00:00 2001 From: Colm Talbot <colm.talbot@ligo.org> Date: Thu, 7 Feb 2019 20:33:26 -0600 Subject: [PATCH] basic checkpointing and resuming --- CHANGELOG.md | 1 + bilby/core/sampler/emcee.py | 88 +++++++++++++++++++++++++++++-------- 2 files changed, 71 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f22560ff..2a08a648 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## Unreleased ### Added +- `emcee` now writes all progress to disk and can resume from a previous run. - ### Changed diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 56e6887d..b52c2bee 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -1,10 +1,13 @@ from __future__ import absolute_import, print_function +import os + import numpy as np from pandas import DataFrame from distutils.version import LooseVersion -from ..utils import logger, get_progress_bar +from ..utils import ( + logger, get_progress_bar, check_directory_exists_and_if_not_mkdir) from .base_sampler import MCMCSampler, SamplerError @@ -41,19 +44,23 @@ class Emcee(MCMCSampler): default_kwargs = dict(nwalkers=500, a=2, args=[], kwargs={}, postargs=None, pool=None, live_dangerously=False, runtime_sortingfn=None, lnprob0=None, rstate0=None, - blobs0=None, iterations=100, thin=1, storechain=True, mh_proposal=None) + blobs0=None, iterations=100, thin=1, storechain=True, + mh_proposal=None) - def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, - skip_import_verification=False, pos0=None, nburn=None, burn_in_fraction=0.25, + def __init__(self, likelihood, priors, outdir='outdir', label='label', + use_ratio=False, plot=False, skip_import_verification=False, + pos0=None, nburn=None, burn_in_fraction=0.25, resume=True, burn_in_act=3, **kwargs): - MCMCSampler.__init__(self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, - use_ratio=use_ratio, plot=plot, - skip_import_verification=skip_import_verification, - **kwargs) + MCMCSampler.__init__( + self, likelihood=likelihood, priors=priors, outdir=outdir, + label=label, use_ratio=use_ratio, plot=plot, + skip_import_verification=skip_import_verification, **kwargs) + self.resume = resume self.pos0 = pos0 self.nburn = nburn self.burn_in_fraction = burn_in_fraction self.burn_in_act = burn_in_act + self._old_chain = None def _translate_kwargs(self, kwargs): if 'nwalkers' not in kwargs: @@ -168,23 +175,54 @@ class Emcee(MCMCSampler): import emcee tqdm = get_progress_bar() sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) - self._set_pos0() - for _ in tqdm(sampler.sample(**self.sampler_function_kwargs), - total=self.nsteps): - pass + out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) + out_file = os.path.join(out_dir, 'chain.dat') + + if self.resume: + self.load_old_chain(out_file) + else: + self._set_pos0() + + check_directory_exists_and_if_not_mkdir(out_dir) + if not os.path.isfile(out_file): + with open(out_file, "w") as ff: + ff.write('walker\t{}\tlog_l'.format( + '\t'.join(self.search_parameter_keys))) + template =\ + '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' + + for sample in tqdm(sampler.sample(**self.sampler_function_kwargs), + total=self.nsteps): + points = np.hstack([sample[0], np.array(sample[3])]) + # import IPython; IPython.embed() + with open(out_file, "a") as ff: + for ii, point in enumerate(points): + ff.write(template.format(ii, *point)) + self.result.sampler_output = np.nan - self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim))) + blobs_flat = np.array(sampler.blobs).reshape((-1, 2)) + log_likelihoods, log_priors = blobs_flat.T + if self._old_chain is not None: + chain = np.vstack([self._old_chain[:, :-2], + sampler.chain.reshape((-1, self.ndim))]) + log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods]) + log_ps = np.hstack([self._old_chain[:, -1], log_priors]) + self.nsteps = chain.shape[0] // self.nwalkers + else: + chain = sampler.chain.reshape((-1, self.ndim)) + log_ls = log_likelihoods + log_ps = log_priors + self.calculate_autocorrelation(chain) self.print_nburn_logging_info() self.result.nburn = self.nburn + n_samples = self.nwalkers * self.nburn if self.result.nburn > self.nsteps: raise SamplerError( "The run has finished, but the chain is not burned in: " "`nburn < nsteps`. Try increasing the number of steps.") - self.result.samples = sampler.chain[:, self.nburn:, :].reshape((-1, self.ndim)) - blobs_flat = np.array(sampler.blobs)[self.nburn:, :, :].reshape((-1, 2)) - log_likelihoods, log_priors = blobs_flat.T - self.result.log_likelihood_evaluations = log_likelihoods - self.result.log_prior_evaluations = log_priors + self.result.samples = chain[n_samples:, :] + self.result.log_likelihood_evaluations = log_ls[n_samples:] + self.result.log_prior_evaluations = log_ps[n_samples:] self.result.walkers = sampler.chain self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan @@ -209,6 +247,20 @@ class Emcee(MCMCSampler): self.pos0 = [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] + def load_old_chain(self, file_name=None): + if file_name is None: + out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) + file_name = os.path.join(out_dir, 'chain.dat') + if os.path.isfile(file_name): + old_chain = np.genfromtxt(file_name, skip_header=1) + self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2]) + for ii in range(self.nwalkers)] + self._old_chain = old_chain[:-self.nwalkers + 1, 1:] + logger.info('Resuming from {}'.format(os.path.abspath(file_name))) + else: + logger.warning('Failed to resume. {} not found.'.format(file_name)) + self._set_pos0() + def lnpostfn(self, theta): log_prior = self.log_prior(theta) if np.isinf(log_prior): -- GitLab