From 9bffa2894cd5b8aa1ee0fcf9b0b8cdf1cdda6045 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 3 Apr 2019 14:36:26 +1100 Subject: [PATCH] Overhaul to using pickles and signals --- bilby/core/sampler/emcee.py | 229 +++++++++++++++++++++------------- bilby/core/sampler/ptemcee.py | 149 +++++++++------------- requirements.txt | 1 + setup.py | 1 + 4 files changed, 197 insertions(+), 183 deletions(-) diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py index 6c1e16972..07d070d0e 100644 --- a/bilby/core/sampler/emcee.py +++ b/bilby/core/sampler/emcee.py @@ -1,10 +1,14 @@ from __future__ import absolute_import, print_function +from collections import namedtuple import os +import signal +import sys import numpy as np from pandas import DataFrame from distutils.version import LooseVersion +import dill as pickle from ..utils import ( logger, get_progress_bar, check_directory_exists_and_if_not_mkdir) @@ -66,6 +70,9 @@ class Emcee(MCMCSampler): self.burn_in_fraction = burn_in_fraction self.burn_in_act = burn_in_act + signal.signal(signal.SIGTERM, self.checkpoint_and_exit) + signal.signal(signal.SIGINT, self.checkpoint_and_exit) + def _translate_kwargs(self, kwargs): if 'nwalkers' not in kwargs: for equiv in self.nwalkers_equiv_kwargs: @@ -165,66 +172,139 @@ class Emcee(MCMCSampler): def nsteps(self, nsteps): self.kwargs['iterations'] = nsteps - def __getstate__(self): - # In order to be picklable with dill, we need to discard the pool - # object before trying. - d = self.__dict__ - d["_Sampler__kwargs"]["pool"] = None - return d + @property + def stored_chain(self): + """ Read the stored zero-temperature chain data in from disk """ + return np.genfromtxt(self.checkpoint_info.chain_file, names=True) - def set_up_checkpoint(self): - out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) - out_file = os.path.join(out_dir, 'chain.dat') + @property + def stored_samples(self): + """ Returns the samples stored on disk """ + return self.stored_chain[self.search_parameter_keys] - if self.resume: - self.load_old_chain(out_file) - else: - self._set_pos0() + @property + def stored_loglike(self): + """ Returns the log-likelihood stored on disk """ + return self.stored_chain['log_l'] + @property + def stored_logprior(self): + """ Returns the log-prior stored on disk """ + return self.stored_chain['log_p'] + + @property + def checkpoint_info(self): + """ Defines various things related to checkpointing and storing data + + Returns + ------- + checkpoint_info: named_tuple + An object with attributes `sampler_file`, `chain_file`, and + `chain_template`. The first two give paths to where the sampler and + chain data is stored, the last a formatted-str-template with which + to write the chain data to disk + + """ + out_dir = os.path.join( + self.outdir, '{}_{}'.format(self.__class__.__name__, self.label)) check_directory_exists_and_if_not_mkdir(out_dir) - if not os.path.isfile(out_file): - with open(out_file, "w") as ff: - ff.write('walker\t{}\tlog_l\n'.format( + + sampler_file = os.path.join(out_dir, 'sampler.pickle') + + # Initialise chain file + chain_file = os.path.join(out_dir, 'chain.dat') + if not os.path.isfile(chain_file): + with open(chain_file, "w") as ff: + ff.write('walker\t{}\tlog_l\tlog_p\n'.format( '\t'.join(self.search_parameter_keys))) - template =\ + chain_template =\ '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' - return out_file, template + CheckpointInfo = namedtuple( + 'CheckpointInfo', ['sampler_file', 'chain_file', 'chain_template']) - def run_sampler(self): + checkpoint_info = CheckpointInfo( + sampler_file=sampler_file, chain_file=chain_file, + chain_template=chain_template) + + return checkpoint_info + + @property + def sampler_chain(self): + nsteps = self._previous_iterations + return self.sampler.chain[:, :nsteps, :] + + def checkpoint(self): + """ Writes a pickle file of the sampler to disk using dill """ + logger.info("Checkpointing sampler to file {}" + .format(self.checkpoint_info.sampler_file)) + with open(self.checkpoint_info.sampler_file, 'wb') as f: + # Overwrites the stored sampler chain with one that is truncated + # to only the completed steps + self.sampler._chain = self.sampler_chain + pickle.dump(self._sampler, f) + + def checkpoint_and_exit(self, signum, frame): + logger.info("Recieved signal {}".format(signum)) + self.checkpoint() + sys.exit() + + def _initialise_sampler(self): import emcee - tqdm = get_progress_bar() - sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) - out_file, template = self.set_up_checkpoint() + self._sampler = emcee.EnsembleSampler(**self.sampler_init_kwargs) + + def _set_pos0_for_resume(self): + self.pos0 = self.sampler.chain[:, -1, :] + + @property + def sampler(self): + """ Returns the ptemcee sampler object + + If, alrady initialized, returns the stored _sampler value. Otherwise, + first checks if there is a pickle file from which to load. If there is + not, then initialize the sampler and set the initial random draw + + """ + if hasattr(self, '_sampler'): + pass + elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file): + with open(self.checkpoint_info.sampler_file, 'rb') as f: + self._sampler = pickle.load(f) + self._set_pos0_for_resume() + else: + self._initialise_sampler() + self._set_pos0() + return self._sampler + def write_chains_to_file(self, sample): + if self.prerelease: + points = np.hstack([sample.coords, sample.blobs]) + else: + points = np.hstack([sample[0], np.array(sample[3])]) + with open(self.checkpoint_info.chain_file, "a") as ff: + for ii, point in enumerate(points): + ff.write(self.checkpoint_info.chain_template.format(ii, *point)) + + def run_sampler(self): + tqdm = get_progress_bar() sampler_function_kwargs = self.sampler_function_kwargs iterations = sampler_function_kwargs.pop('iterations') iterations -= self._previous_iterations + print('pos0', self.pos0) + sampler_function_kwargs['p0'] = self.pos0 + for sample in tqdm( - sampler.sample(iterations=iterations, **sampler_function_kwargs), + self.sampler.sample(iterations=iterations, **sampler_function_kwargs), total=iterations): - if self.prerelease: - points = np.hstack([sample.coords, sample.blobs]) - else: - points = np.hstack([sample[0], np.array(sample[3])]) - with open(out_file, "a") as ff: - for ii, point in enumerate(points): - ff.write(template.format(ii, *point)) + self.write_chains_to_file(sample) self.result.sampler_output = np.nan - blobs_flat = np.array(sampler.blobs).reshape((-1, 2)) + blobs_flat = np.array(self.sampler.blobs).reshape((-1, 2)) log_likelihoods, log_priors = blobs_flat.T - if self._old_chain is not None: - chain = np.vstack([self._old_chain[:, :-2], - sampler.chain.reshape((-1, self.ndim))]) - log_ls = np.hstack([self._old_chain[:, -2], log_likelihoods]) - log_ps = np.hstack([self._old_chain[:, -1], log_priors]) - self.nsteps = chain.shape[0] // self.nwalkers - else: - chain = sampler.chain.reshape((-1, self.ndim)) - log_ls = log_likelihoods - log_ps = log_priors + chain = self.sampler.chain.reshape((-1, self.ndim)) + log_ls = log_likelihoods + log_ps = log_priors self.calculate_autocorrelation(chain) self.print_nburn_logging_info() self.result.nburn = self.nburn @@ -236,13 +316,27 @@ class Emcee(MCMCSampler): self.result.samples = chain[n_samples:, :] self.result.log_likelihood_evaluations = log_ls[n_samples:] self.result.log_prior_evaluations = log_ps[n_samples:] - self.result.walkers = sampler.chain + self.result.walkers = self.sampler.chain self.result.log_evidence = np.nan self.result.log_evidence_err = np.nan return self.result + @property + def _previous_iterations(self): + """ Returns the number of iterations that the sampler has saved + + This is used when loading in a sampler from a pickle file to figure out + how much of the run has already been completed + """ + return len(self.sampler.blobs) + def _draw_pos0_from_prior(self): - return [self.get_random_draw_from_prior() for _ in range(self.nwalkers)] + return np.array( + [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]) + + @property + def _pos0_shape(self): + return (self.nwalkers, self.ndim) def _set_pos0(self): if self.pos0 is not None: @@ -250,9 +344,9 @@ class Emcee(MCMCSampler): if isinstance(self.pos0, DataFrame): self.pos0 = self.pos0[self.search_parameter_keys].values elif type(self.pos0) in (list, np.ndarray): - self.pos0 = np.squeeze(self.kwargs['pos0']) + self.pos0 = np.squeeze(self.pos0) - if self.pos0.shape != (self.nwalkers, self.ndim): + if self.pos0.shape != self._pos0_shape: raise ValueError( 'Input pos0 should be of shape ndim, nwalkers') logger.debug("Checking input pos0") @@ -262,51 +356,6 @@ class Emcee(MCMCSampler): logger.debug("Generating initial walker positions from prior") self.pos0 = self._draw_pos0_from_prior() - @property - def _old_chain(self): - try: - old_chain = self.__old_chain - n = old_chain.shape[0] - idx = n - np.mod(n, self.nwalkers) - return old_chain[:idx, :] - except AttributeError: - return None - - @_old_chain.setter - def _old_chain(self, old_chain): - self.__old_chain = old_chain - - @property - def _previous_iterations(self): - if self._old_chain is None: - return 0 - try: - return self._old_chain.shape[0] // self.nwalkers - except AttributeError: - logger.warning( - "Unable to calculate previous iterations from checkpoint," - " defaulting to zero") - return 0 - - def load_old_chain(self, file_name=None): - if file_name is None: - out_dir = os.path.join(self.outdir, 'emcee_{}'.format(self.label)) - file_name = os.path.join(out_dir, 'chain.dat') - if os.path.isfile(file_name): - try: - old_chain = np.genfromtxt(file_name, skip_header=1) - self.pos0 = [np.squeeze(old_chain[-(self.nwalkers - ii), 1:-2]) - for ii in range(self.nwalkers)] - self._old_chain = old_chain[:-self.nwalkers + 1, 1:] - logger.info('Resuming from {}'.format(os.path.abspath(file_name))) - except Exception: - logger.warning('Failed to resume. Corrupt checkpoint file {}.' - .format(file_name)) - self._set_pos0() - else: - logger.warning('Failed to resume. {} not found.'.format(file_name)) - self._set_pos0() - def lnpostfn(self, theta): log_prior = self.log_prior(theta) if np.isinf(log_prior): diff --git a/bilby/core/sampler/ptemcee.py b/bilby/core/sampler/ptemcee.py index 7f02f45e1..3325d3a55 100644 --- a/bilby/core/sampler/ptemcee.py +++ b/bilby/core/sampler/ptemcee.py @@ -1,12 +1,8 @@ from __future__ import absolute_import, division, print_function -import os -from collections import namedtuple - import numpy as np -from ..utils import ( - logger, get_progress_bar, check_directory_exists_and_if_not_mkdir) +from ..utils import logger, get_progress_bar from . import Emcee from .base_sampler import SamplerError @@ -31,12 +27,11 @@ class Ptemcee(Emcee): The number of temperatures used by ptemcee """ - default_kwargs = dict(ntemps=2, nwalkers=500, Tmax=None, betas=None, - threads=1, pool=None, a=2.0, loglargs=[], logpargs=[], - loglkwargs={}, logpkwargs={}, adaptation_lag=10000, - adaptation_time=100, random=None, iterations=100, - thin=1, storechain=True, adapt=True, swap_ratios=False, - ) + default_kwargs = dict( + ntemps=2, nwalkers=500, Tmax=None, betas=None, threads=1, pool=None, + a=2.0, loglargs=[], logpargs=[], loglkwargs={}, logpkwargs={}, + adaptation_lag=10000, adaptation_time=100, random=None, iterations=100, + thin=1, storechain=True, adapt=True, swap_ratios=False) def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, @@ -61,120 +56,88 @@ class Ptemcee(Emcee): if key not in self.sampler_function_kwargs} @property - def checkpoint_info(self): - out_dir = os.path.join(self.outdir, 'ptemcee_{}'.format(self.label)) - chain_file = os.path.join(out_dir, 'chain.dat') - last_pos_file = os.path.join(out_dir, 'last_pos.npy') - - check_directory_exists_and_if_not_mkdir(out_dir) - if not os.path.isfile(chain_file): - with open(chain_file, "w") as ff: - ff.write('walker\t{}\tlog_l\tlog_p\n'.format( - '\t'.join(self.search_parameter_keys))) - template =\ - '{:d}' + '\t{:.9e}' * (len(self.search_parameter_keys) + 2) + '\n' - - CheckpointInfo = namedtuple( - 'CheckpointInfo', ['last_pos_file', 'chain_file', 'template']) - - checkpoint_info = CheckpointInfo( - last_pos_file=last_pos_file, chain_file=chain_file, template=template) - - return checkpoint_info + def ntemps(self): + return self.kwargs['ntemps'] def _draw_pos0_from_prior(self): + # for ptemcee, the pos0 has the shape ntemps, nwalkers, ndim return [[self.get_random_draw_from_prior() for _ in range(self.nwalkers)] for _ in range(self.kwargs['ntemps'])] - @property - def _old_chain(self): - try: - old_chain = self.__old_chain - n = old_chain.shape[0] - idx = n - np.mod(n, self.nwalkers) - return old_chain[:idx] - except AttributeError: - return None - - @_old_chain.setter - def _old_chain(self, old_chain): - self.__old_chain = old_chain + def _set_pos0_for_resume(self): + self.pos0 = None @property - def stored_chain(self): - return np.genfromtxt(self.checkpoint_info.chain_file, names=True) + def _previous_iterations(self): + """ Returns the number of iterations that the sampler has saved - @property - def stored_samples(self): - return self.stored_chain[self.search_parameter_keys] + This is used when loading in a sampler from a pickle file to figure out + how much of the run has already been completed + """ + return self.sampler.time @property - def stored_loglike(self): - return self.stored_chain['log_l'] + def sampler_chain(self): + nsteps = self._previous_iterations + return self.sampler.chain[:, :, :nsteps, :] @property - def stored_logprior(self): - return self.stored_chain['log_p'] - - def load_old_chain(self): - try: - last_pos = np.load(self.checkpoint_info.last_pos_file) - self.pos0 = last_pos - self._old_chain = self.stored_samples - logger.info( - 'Resuming from {} with {} iterations'.format( - self.checkpoint_info.chain_file, - self._previous_iterations)) - except Exception: - logger.info('Unable to resume') - self._set_pos0() + def _pos0_shape(self): + return (self.ntemps, self.nwalkers, self.ndim) - def run_sampler(self): + def _initialise_sampler(self): import ptemcee - tqdm = get_progress_bar() - sampler = ptemcee.Sampler(dim=self.ndim, logl=self.log_likelihood, - logp=self.log_prior, **self.sampler_init_kwargs) - - if self.resume: - self.load_old_chain() - else: - self._set_pos0() + self._sampler = ptemcee.Sampler( + dim=self.ndim, logl=self.log_likelihood, logp=self.log_prior, + **self.sampler_init_kwargs) + + def print_tswap_acceptance_fraction(self): + logger.info("Sampler per-chain tswap acceptance fraction = {}".format( + self.sampler.tswap_acceptance_fraction)) + + def write_chains_to_file(self, pos, loglike, logpost): + with open(self.checkpoint_info.chain_file, "a") as ff: + loglike = np.squeeze(loglike[0, :]) + logprior = np.squeeze(logpost[0, :]) - loglike + for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)): + line = np.concatenate((point, [logl, logp])) + ff.write(self.checkpoint_info.chain_template.format(ii, *line)) + def run_sampler(self): + tqdm = get_progress_bar() sampler_function_kwargs = self.sampler_function_kwargs iterations = sampler_function_kwargs.pop('iterations') iterations -= self._previous_iterations + # main iteration loop for pos, logpost, loglike in tqdm( - sampler.sample(self.pos0, iterations=iterations, - **sampler_function_kwargs), + self.sampler.sample(self.pos0, iterations=iterations, + **sampler_function_kwargs), total=iterations): - np.save(self.checkpoint_info.last_pos_file, pos) - with open(self.checkpoint_info.chain_file, "a") as ff: - loglike = np.squeeze(loglike[:1, :]) - logprior = np.squeeze(logpost[:1, :]) - loglike - for ii, (point, logl, logp) in enumerate(zip(pos[0, :, :], loglike, logprior)): - line = np.concatenate((point, [logl, logp])) - ff.write(self.checkpoint_info.template.format(ii, *line)) - - self.calculate_autocorrelation(sampler.chain.reshape((-1, self.ndim))) + self.write_chains_to_file(pos, loglike, logpost) + + self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim))) self.result.sampler_output = np.nan self.print_nburn_logging_info() + self.print_tswap_acceptance_fraction() + self.result.nburn = self.nburn if self.result.nburn > self.nsteps: raise SamplerError( "The run has finished, but the chain is not burned in: " "`nburn < nsteps`. Try increasing the number of steps.") - walkers = self.stored_samples.view((float, self.ndim)) - walkers = walkers.reshape(self.nwalkers, self.nsteps, self.ndim) - self.result.walkers = walkers - self.result.samples = walkers[:, self.nburn:, :].reshape((-1, self.ndim)) + + self.result.samples = self.sampler.chain[0, :, self.nburn:, :].reshape( + (-1, self.ndim)) + self.result.walkers = self.sampler.chain[0, :, :, :] + n_samples = self.nwalkers * self.nburn self.result.log_likelihood_evaluations = self.stored_loglike[n_samples:] self.result.log_prior_evaluations = self.stored_logprior[n_samples:] - self.result.betas = sampler.betas + self.result.betas = self.sampler.betas self.result.log_evidence, self.result.log_evidence_err =\ - sampler.log_evidence_estimate( - sampler.loglikelihood, self.nburn / self.nsteps) + self.sampler.log_evidence_estimate( + self.sampler.loglikelihood, self.nburn / self.nsteps) return self.result diff --git a/requirements.txt b/requirements.txt index bb184b6a6..de58a6b16 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ matplotlib>=2.0 scipy>=0.16 pandas mock +dill diff --git a/setup.py b/setup.py index 81575551b..b535cc65e 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ setup(name='bilby', 'future', 'dynesty', 'corner', + 'dill', 'numpy>=1.9', 'matplotlib>=2.0', 'pandas', -- GitLab