diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index d11334ee6d8174d31c4646bd5ec18cec8d1a3d99..30e7ac6ac81aa8a704b1fc39a2a3b8520176fb77 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -1,17 +1,19 @@ from __future__ import absolute_import import datetime +import dill import os import sys import pickle import signal +import time import tqdm import matplotlib.pyplot as plt import numpy as np from pandas import DataFrame -from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect +from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump from .base_sampler import Sampler, NestedSampler from numpy import linalg @@ -70,14 +72,14 @@ class Dynesty(NestedSampler): check_point_plot: bool, If true, generate a trace plot along with the check-point check_point_delta_t: float (600) - The approximate checkpoint period (in seconds). Should the run be - interrupted, it can be resumed from the last checkpoint. Set to - `None` to turn-off check pointing + The minimum checkpoint period (in seconds). Should the run be + interrupted, it can be resumed from the last checkpoint. n_check_point: int, optional (None) - The number of steps to take before check pointing (override - check_point_delta_t). + The number of steps to take before checking whether to check_point. resume: bool If true, resume run from checkpoint (if available) + exit_code: int + The code which the same exits on if it hasn't finished sampling """ default_kwargs = dict(bound='multi', sample='rwalk', verbose=True, periodic=None, reflective=None, @@ -98,7 +100,7 @@ class Dynesty(NestedSampler): def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, check_point=True, check_point_plot=True, n_check_point=None, - check_point_delta_t=600, resume=True, **kwargs): + check_point_delta_t=600, resume=True, exit_code=130, **kwargs): super(Dynesty, self).__init__(likelihood=likelihood, priors=priors, outdir=outdir, label=label, use_ratio=use_ratio, plot=plot, skip_import_verification=skip_import_verification, @@ -110,19 +112,16 @@ class Dynesty(NestedSampler): self._periodic = list() self._reflective = list() self._apply_dynesty_boundaries() - if self.n_check_point is None: - # If the log_likelihood_eval_time is not calculable then - # check_point is set to False. - if np.isnan(self._log_likelihood_eval_time): - self.check_point = False - n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time) - n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw))) - self.n_check_point = n_check_point_rnd - logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point)) + if self.n_check_point is None: + self.n_check_point = 1000 + self.check_point_delta_t = check_point_delta_t + logger.info("Checkpoint every check_point_delta_t = {}s" + .format(check_point_delta_t)) self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label) self.sampling_time = datetime.timedelta() + self.exit_code = exit_code signal.signal(signal.SIGTERM, self.write_current_state_and_exit) signal.signal(signal.SIGINT, self.write_current_state_and_exit) @@ -131,7 +130,8 @@ class Dynesty(NestedSampler): def __getstate__(self): """ For pickle: remove external_sampler, which can be an unpicklable "module" """ state = self.__dict__.copy() - del state['external_sampler'] + if "external_sampler" in state: + del state['external_sampler'] return state @property @@ -319,9 +319,9 @@ class Dynesty(NestedSampler): def _run_external_sampler_with_checkpointing(self): logger.debug("Running sampler with checkpointing") if self.resume: - resume = self.read_saved_state(continuing=True) - if resume: - logger.info('Resuming from previous run.') + resume_file_loaded = self.read_saved_state(continuing=True) + if resume_file_loaded: + logger.info('Resume file successfully loaded.') old_ncall = self.sampler.ncall sampler_kwargs = self.sampler_function_kwargs.copy() @@ -334,8 +334,15 @@ class Dynesty(NestedSampler): break old_ncall = self.sampler.ncall - self.sampler._remove_live_points() - self.write_current_state() + if os.path.isfile(self.resume_file): + last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file) + else: + last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds() + if last_checkpoint_s > self.check_point_delta_t: + self.write_current_state() + self.plot_current_state() + if self.sampler.added_live: + self.sampler._remove_live_points() sampler_kwargs['add_live'] = True self._run_nested_wrapper(sampler_kwargs) @@ -348,152 +355,147 @@ class Dynesty(NestedSampler): def read_saved_state(self, continuing=False): """ - Read a saved state of the sampler to disk. + Read a pickled saved state of the sampler to disk. - The required information to reconstruct the state of the run is read - from a pickle file. - This currently adds the whole chain to the sampler. - We then remove the old checkpoint and write all unnecessary items back - to disk. - FIXME: Load only the necessary quantities, rather than read/write? + If the live points are present and the run is continuing + they are removed. + The random state must be reset, as this isn't saved by the pickle. + `nqueue` is set to a negative number to trigger the queue to be + refilled before the first iteration. + The previous run time is set to self. Parameters ---------- - sampler: `dynesty.NestedSampler` - NestedSampler instance to reconstruct from the saved state. continuing: bool Whether the run is continuing or terminating, if True, the loaded state is mostly written back to disk. """ - + from ... import __version__ as bilby_version + from dynesty import __version__ as dynesty_version + versions = dict(bilby=bilby_version, dynesty=dynesty_version) if os.path.isfile(self.resume_file): logger.info("Reading resume file {}".format(self.resume_file)) - try: - with open(self.resume_file, 'rb') as file: - saved = pickle.load(file) - logger.info( - "Succesfuly read resume file {}".format(self.resume_file)) - except EOFError as e: - logger.warning("Resume file reading failed with error {}".format(e)) - return False - - self.sampler.saved_u = list(saved['unit_cube_samples']) - self.sampler.saved_v = list(saved['physical_samples']) - self.sampler.saved_logl = list(saved['sample_likelihoods']) - self.sampler.saved_logvol = list(saved['sample_log_volume']) - self.sampler.saved_logwt = list(saved['sample_log_weights']) - self.sampler.saved_logz = list(saved['cumulative_log_evidence']) - self.sampler.saved_logzvar = list(saved['cumulative_log_evidence_error']) - self.sampler.saved_id = list(saved['id']) - self.sampler.saved_it = list(saved['it']) - self.sampler.saved_nc = list(saved['nc']) - self.sampler.saved_boundidx = list(saved['boundidx']) - self.sampler.saved_bounditer = list(saved['bounditer']) - self.sampler.saved_scale = list(saved['scale']) - self.sampler.saved_h = list(saved['cumulative_information']) - self.sampler.ncall = saved['ncall'] - self.sampler.live_logl = list(saved['live_logl']) - self.sampler.it = saved['iteration'] + 1 - self.sampler.live_u = saved['live_u'] - self.sampler.live_v = saved['live_v'] - self.sampler.nlive = saved['nlive'] - self.sampler.live_bound = saved['live_bound'] - self.sampler.live_it = saved['live_it'] - self.sampler.added_live = saved['added_live'] - self.sampler.bound = saved['bound'] - self.sampler.nbound = saved['nbound'] - self.sampling_time += datetime.timedelta(seconds=saved['sampling_time']) + with open(self.resume_file, 'rb') as file: + sampler = dill.load(file) + + if not hasattr(sampler, "versions"): + logger.warning( + "The resume file {} is corrupted or the version of " + "bilby has changed between runs. This resume file will " + "be ignored." + .format(self.resume_file) + ) + return False + version_warning = ( + "The {code} version has changed between runs. " + "This may cause unpredictable behaviour and/or failure. " + "Old version = {old}, new version = {new}." + + ) + for code in versions: + if not versions[code] == sampler.versions.get(code, None): + logger.warning(version_warning.format( + code=code, + old=sampler.versions.get(code, "None"), + new=versions[code] + )) + del sampler.versions + self.sampler = sampler + if self.sampler.added_live and continuing: + self.sampler._remove_live_points() + self.sampler.nqueue = -1 + self.sampler.rstate = np.random + self.start_time = self.sampler.kwargs.pop("start_time") + self.sampling_time = self.sampler.kwargs.pop("sampling_time") return True - else: - logger.debug( - "No resume file {}".format(self.resume_file)) + logger.info( + "Resume file {} does not exist.".format(self.resume_file)) return False def write_current_state_and_exit(self, signum=None, frame=None): - logger.warning("Run terminated with signal {}".format(signum)) - self.write_current_state(plot=False) - sys.exit(130) + if signum == 14: + logger.info( + "Run interrupted by alarm signal {}: checkpoint and exit on {}" + .format(signum, self.exit_code)) + else: + logger.info( + "Run interrupted by signal {}: checkpoint and exit on {}" + .format(signum, self.exit_code)) + self.write_current_state() + os._exit(self.exit_code) - def write_current_state(self, plot=True): + def write_current_state(self): """ Write the current state of the sampler to disk. - The required information to reconstruct the state of the run are written - to an hdf5 file. - All but the most recent removed live point in the chain are removed from - the sampler to reduce memory usage. - This means it is necessary to not append the first live point to the - file if updating a previous checkpoint. + The sampler is pickle dumped using `dill`. + The sampling time is also stored to get the full CPU time for the run. - Parameters - ---------- - sampler: `dynesty.NestedSampler` - NestedSampler to write to disk. + The check of whether the sampler is picklable is to catch an error + when using pytest. Hopefully, this message won't be triggered during + normal running. """ - check_directory_exists_and_if_not_mkdir(self.outdir) - print("") - logger.info("Writing checkpoint file {}".format(self.resume_file)) + from ... import __version__ as bilby_version + from dynesty import __version__ as dynesty_version + check_directory_exists_and_if_not_mkdir(self.outdir) end_time = datetime.datetime.now() if hasattr(self, 'start_time'): self.sampling_time += end_time - self.start_time self.start_time = end_time - - current_state = dict( - unit_cube_samples=self.sampler.saved_u, - physical_samples=self.sampler.saved_v, - sample_likelihoods=self.sampler.saved_logl, - sample_log_volume=self.sampler.saved_logvol, - sample_log_weights=self.sampler.saved_logwt, - cumulative_log_evidence=self.sampler.saved_logz, - cumulative_log_evidence_error=self.sampler.saved_logzvar, - cumulative_information=self.sampler.saved_h, - id=self.sampler.saved_id, - it=self.sampler.saved_it, - nc=self.sampler.saved_nc, - bound=self.sampler.bound, - nbound=self.sampler.nbound, - boundidx=self.sampler.saved_boundidx, - bounditer=self.sampler.saved_bounditer, - scale=self.sampler.saved_scale, - sampling_time=self.sampling_time.total_seconds() - ) - - current_state.update( - ncall=self.sampler.ncall, live_logl=self.sampler.live_logl, - iteration=self.sampler.it - 1, live_u=self.sampler.live_u, - live_v=self.sampler.live_v, nlive=self.sampler.nlive, - live_bound=self.sampler.live_bound, live_it=self.sampler.live_it, - added_live=self.sampler.added_live + self.sampler.kwargs["sampling_time"] = self.sampling_time + self.sampler.kwargs["start_time"] = self.start_time + self.sampler.versions = dict( + bilby=bilby_version, dynesty=dynesty_version ) + if dill.pickles(self.sampler): + safe_file_dump(self.sampler, self.resume_file, dill) + logger.info("Written checkpoint file {}".format(self.resume_file)) + else: + logger.warning( + "Cannot write pickle resume file! " + "Job will not resume if interrupted." + ) - try: - weights = np.exp(current_state['sample_log_weights'] - - current_state['cumulative_log_evidence'][-1]) - from dynesty.utils import resample_equal - - current_state['posterior'] = resample_equal( - np.array(current_state['physical_samples']), weights) - current_state['search_parameter_keys'] = self.search_parameter_keys - except ValueError: - logger.debug("Unable to create posterior") - - with open(self.resume_file, 'wb') as file: - pickle.dump(current_state, file) - - if plot and self.check_point_plot: + def plot_current_state(self): + if self.check_point_plot: import dynesty.plotting as dyplot labels = [label.replace('_', ' ') for label in self.search_parameter_keys] - filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) try: + filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label) fig = dyplot.traceplot(self.sampler.results, labels=labels)[0] fig.tight_layout() fig.savefig(filename) - plt.close('all') except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e: logger.warning(e) logger.warning('Failed to create dynesty state plot at checkpoint') + finally: + plt.close("all") + try: + filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label) + fig, axs = dyplot.runplot(self.sampler.results) + fig.tight_layout() + plt.savefig(filename) + except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e: + logger.warning(e) + logger.warning('Failed to create dynesty run plot at checkpoint') + finally: + plt.close('all') + try: + filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label) + fig, axs = plt.subplots(nrows=3, sharex=True) + for ax, name in zip(axs, ["boundidx", "nc", "scale"]): + ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0") + ax.set_ylabel(name) + axs[-1].set_xlabel("iteration") + fig.tight_layout() + plt.savefig(filename) + except (RuntimeError, ValueError) as e: + logger.warning(e) + logger.warning('Failed to create dynesty stats plot at checkpoint') + finally: + plt.close('all') def generate_trace_plots(self, dynesty_results): check_directory_exists_and_if_not_mkdir(self.outdir) @@ -573,7 +575,6 @@ def sample_rwalk_bilby(args): logl_list = [loglikelihood(v_list[-1])] max_walk_warning = True - drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan while len(u_list) < nact * act: # Propose a direction on the unit n-sphere. diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 6d41a957710d54a9018fc0ea617a01eff870ab8e..71b0b50285d917579c1a66e3b66ded4edcdb1a82 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -1101,6 +1101,25 @@ def reflect(u): return u +def safe_file_dump(data, filename, module): + """ Safely dump data to a .pickle file + + Parameters + ---------- + data: + data to dump + filename: str + The file to dump to + module: pickle, dill + The python module to use + """ + + temp_filename = filename + ".temp" + with open(temp_filename, "wb") as file: + module.dump(data, file) + os.rename(temp_filename, filename) + + def latex_plot_format(func): """ Wrap a plotting function to set rcParams so that text renders nicely with diff --git a/setup.cfg b/setup.cfg index 4a3da3b80cb42da94bbb65788d2e33f5a88a773e..396900f05c8e5032998a697bd738bddaf21f9dd2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] exclude = .git,docs,build,dist,test,*__init__.py max-line-length = 120 -ignore = E129 W504 W605 +ignore = E129 W503 W504 W605 [tool:pytest] addopts =