From 57a2373714090ecac3ea9ec1e38dd7b0e112fb9c Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Tue, 2 Apr 2019 16:03:54 +1100 Subject: [PATCH] Initial work on adding checkpointing to ptemcee --- bilby/core/sampler/emcee.py | 72 ++++++++++++++++---- bilby/core/sampler/ptemcee.py | 124 ++++++++++++++++++++++++++++------ 2 files changed, 161 insertions(+), 35 deletions(-) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 3117a6c7a..6c1e16972 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -65,7 +65,6 @@ class Emcee(MCMCSampler): self.nburn = nburn self.burn_in_fraction = burn_in_fraction self.burn_in_act = burn_in_act - self._old_chain = None def _translate_kwargs(self, kwargs): if 'nwalkers' not in kwargs: @@ -173,10 +172,7 @@ class Emcee(MCMCSampler): d["_Sampler__kwargs"]["pool"] = None return d - def run_sampler(self): - import emcee - tqdm = get_progress_bar() - sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) + def set_up_checkpoint(self): out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) out_file = os.path.join(out_dir, 'chain.dat') @@ -188,13 +184,26 @@ class Emcee(MCMCSampler): check_directory_exists_and_if_not_mkdir(out_dir) if not os.path.isfile(out_file): with open(out_file, "w") as ff: - ff.write('walker\t{}\tlog_l'.format( + ff.write('walker\t{}\tlog_l\n'.format( '\t'.join(self.search_parameter_keys))) template =\ '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' - for sample in tqdm(sampler.sample(**self.sampler_function_kwargs), - total=self.nsteps): + return out_file, template + + def run_sampler(self): + import emcee + tqdm = get_progress_bar() + sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) + out_file, template = self.set_up_checkpoint() + + sampler_function_kwargs = self.sampler_function_kwargs + iterations = sampler_function_kwargs.pop('iterations') + iterations -= self._previous_iterations + + for sample in tqdm( + sampler.sample(iterations=iterations, **sampler_function_kwargs), + total=iterations): if self.prerelease: points = np.hstack([sample.coords, sample.blobs]) else: @@ -232,6 +241,9 @@ class Emcee(MCMCSampler): self.result.log_evidence_err = np.nan return self.result + def _draw_pos0_from_prior(self): + return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] + def _set_pos0(self): if self.pos0 is not None: logger.debug("Using given initial positions for walkers") @@ -248,19 +260,49 @@ class Emcee(MCMCSampler): self.check_draw(draw) else: logger.debug("Generating initial walker positions from prior") - self.pos0 = [self.get_random_draw_from_prior() - for _ in range(self.nwalkers)] + self.pos0 = self._draw_pos0_from_prior() + + @property + def _old_chain(self): + try: + old_chain = self.__old_chain + n = old_chain.shape[0] + idx = n - np.mod(n, self.nwalkers) + return old_chain[:idx, :] + except AttributeError: + return None + + @_old_chain.setter + def _old_chain(self, old_chain): + self.__old_chain = old_chain + + @property + def _previous_iterations(self): + if self._old_chain is None: + return 0 + try: + return self._old_chain.shape[0] // self.nwalkers + except AttributeError: + logger.warning( + "Unable to calculate previous iterations from checkpoint," + " defaulting to zero") + return 0 def load_old_chain(self, file_name=None): if file_name is None: out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) file_name = os.path.join(out_dir, 'chain.dat') if os.path.isfile(file_name): - old_chain = np.genfromtxt(file_name, skip_header=1) - self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2]) - for ii in range(self.nwalkers)] - self._old_chain = old_chain[:-self.nwalkers + 1, 1:] - logger.info('Resuming from {}'.format(os.path.abspath(file_name))) + try: + old_chain = np.genfromtxt(file_name, skip_header=1) + self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2]) + for ii in range(self.nwalkers)] + self._old_chain = old_chain[:-self.nwalkers + 1, 1:] + logger.info('Resuming from {}'.format(os.path.abspath(file_name))) + except Exception: + logger.warning('Failed to resume. Corrupt checkpoint file {}.' + .format(file_name)) + self._set_pos0() else: logger.warning('Failed to resume. {} not found.'.format(file_name)) self._set_pos0() diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index e0c2401ca..7f02f45e1 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -1,8 +1,12 @@ from __future__ import absolute_import, division, print_function +import os +from collections import namedtuple + import numpy as np -from ..utils import get_progress_bar +from ..utils import ( + logger, get_progress_bar, check_directory_exists_and_if_not_mkdir) from . import Emcee from .base_sampler import SamplerError @@ -36,13 +40,14 @@ class Ptemcee(Emcee): def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, - nburn=None, burn_in_fraction=0.25, burn_in_act=3, **kwargs): + nburn=None, burn_in_fraction=0.25, burn_in_act=3, resume=True, + **kwargs): Emcee.__init__( self, likelihood=likelihood, priors=priors, outdir=outdir, label=label, use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, nburn=nburn, burn_in_fraction=burn_in_fraction, - burn_in_act=burn_in_act, **kwargs) + burn_in_act=burn_in_act, resume=True, **kwargs) @property def sampler_function_kwargs(self): @@ -55,23 +60,102 @@ class Ptemcee(Emcee): for key, value in self.kwargs.items() if key not in self.sampler_function_kwargs} + @property + def checkpoint_info(self): + out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label)) + chain_file = os.path.join(out_dir, 'chain.dat') + last_pos_file = os.path.join(out_dir, 'last_pos.npy') + + check_directory_exists_and_if_not_mkdir(out_dir) + if not os.path.isfile(chain_file): + with open(chain_file, "w") as ff: + ff.write('walker\t{}\tlog_l\tlog_p\n'.format( + '\t'.join(self.search_parameter_keys))) + template =\ + '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' + + CheckpointInfo = namedtuple( + 'CheckpointInfo', ['last_pos_file', 'chain_file', 'template']) + + checkpoint_info = CheckpointInfo( + last_pos_file=last_pos_file, chain_file=chain_file, template=template) + + return checkpoint_info + + def _draw_pos0_from_prior(self): + return [[self.get_random_draw_from_prior() + for _ in range(self.nwalkers)] + for _ in range(self.kwargs['ntemps'])] + + @property + def _old_chain(self): + try: + old_chain = self.__old_chain + n = old_chain.shape[0] + idx = n - np.mod(n, self.nwalkers) + return old_chain[:idx] + except AttributeError: + return None + + @_old_chain.setter + def _old_chain(self, old_chain): + self.__old_chain = old_chain + + @property + def stored_chain(self): + return np.genfromtxt(self.checkpoint_info.chain_file, names=True) + + @property + def stored_samples(self): + return self.stored_chain[self.search_parameter_keys] + + @property + def stored_loglike(self): + return self.stored_chain['log_l'] + + @property + def stored_logprior(self): + return self.stored_chain['log_p'] + + def load_old_chain(self): + try: + last_pos = np.load(self.checkpoint_info.last_pos_file) + self.pos0 = last_pos + self._old_chain = self.stored_samples + logger.info( + 'Resuming from {} with {} iterations'.format( + self.checkpoint_info.chain_file, + self._previous_iterations)) + except Exception: + logger.info('Unable to resume') + self._set_pos0() + def run_sampler(self): import ptemcee tqdm = get_progress_bar() sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior, **self.sampler_init_kwargs) - self.pos0 = [[self.get_random_draw_from_prior() - for _ in range(self.nwalkers)] - for _ in range(self.kwargs['ntemps'])] - log_likelihood_evaluations = [] - log_prior_evaluations = [] + if self.resume: + self.load_old_chain() + else: + self._set_pos0() + + sampler_function_kwargs = self.sampler_function_kwargs + iterations = sampler_function_kwargs.pop('iterations') + iterations -= self._previous_iterations + for pos, logpost, loglike in tqdm( - sampler.sample(self.pos0, **self.sampler_function_kwargs), - total=self.nsteps): - log_likelihood_evaluations.append(loglike) - log_prior_evaluations.append(logpost - loglike) - pass + sampler.sample(self.pos0, iterations=iterations, + **sampler_function_kwargs), + total=iterations): + np.save(self.checkpoint_info.last_pos_file, pos) + with open(self.checkpoint_info.chain_file, "a") as ff: + loglike = np.squeeze(loglike[:1, :]) + logprior = np.squeeze(logpost[:1, :]) - loglike + for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)): + line = np.concatenate((point, [logl, logp])) + ff.write(self.checkpoint_info.template.format(ii, *line)) self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim))) self.result.sampler_output = np.nan @@ -81,16 +165,16 @@ class Ptemcee(Emcee): raise SamplerError( "The run has finished, but the chain is not burned in: " "`nburn < nsteps`. Try increasing the number of steps.") - self.result.samples = sampler.chain[0, :, self.nburn:, :].reshape( - (-1, self.ndim)) - self.result.log_likelihood_evaluations = np.array( - log_likelihood_evaluations)[self.nburn:, 0, :].reshape((-1)) - self.result.log_prior_evaluations = np.array( - log_prior_evaluations)[self.nburn:, 0, :].reshape((-1)) + walkers = self.stored_samples.view((float, self.ndim)) + walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim) + self.result.walkers = walkers + self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim)) + n_samples = self.nwalkers * self.nburn + self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:] + self.result.log_prior_evaluations = self.stored_logprior[n_samples:] self.result.betas = sampler.betas self.result.log_evidence, self.result.log_evidence_err =\ sampler.log_evidence_estimate( sampler.loglikelihood, self.nburn / self.nsteps) - self.result.walkers = sampler.chain[0, :, :, :] return self.result -- GitLab