diff --git a/bilby/core/prior/dict.py b/bilby/core/prior/dict.py
index ca2eed960589fe5cb8d0bdfc05ca470fa00dfce5..a375625b3082ecf54d9005ff85fef766671386b2 100644
--- a/bilby/core/prior/dict.py
+++ b/bilby/core/prior/dict.py
@@ -39,6 +39,7 @@ class PriorDict(dict):
             self.from_file(filename)
         elif dictionary is not None:
             raise ValueError("PriorDict input dictionary not understood")
+        self._cached_normalizations = {}

         self.convert_floats_to_delta_functions()

@@ -383,6 +384,27 @@ class PriorDict(dict):
                            if not isinstance(self[key], Constraint)}
         return all_samples

+    def normalize_constraint_factor(self, keys):
+        if keys in self._cached_normalizations.keys():
+            return self._cached_normalizations[keys]
+        else:
+            min_accept = 1000
+            sampling_chunk = 5000
+            samples = self.sample_subset(keys=keys, size=sampling_chunk)
+            keep = np.atleast_1d(self.evaluate_constraints(samples))
+            if len(keep) == 1:
+                return 1
+            all_samples = {key: np.array([]) for key in keys}
+            while np.count_nonzero(keep) < min_accept:
+                samples = self.sample_subset(keys=keys, size=sampling_chunk)
+                for key in samples:
+                    all_samples[key] = np.hstack(
+                        [all_samples[key], samples[key].flatten()])
+                keep = np.array(self.evaluate_constraints(all_samples), dtype=bool)
+            factor = len(keep) / np.count_nonzero(keep)
+            self._cached_normalizations[keys] = factor
+            return factor
+
     def prob(self, sample, **kwargs):
         """

@@ -401,6 +423,7 @@ class PriorDict(dict):

         prob = np.product([self[key].prob(sample[key])
                            for key in sample], **kwargs)
+        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(prob == 0.):
             return prob
         else:
@@ -412,7 +435,7 @@ class PriorDict(dict):
             else:
                 constrained_prob = np.zeros_like(prob)
                 keep = np.array(self.evaluate_constraints(sample), dtype=bool)
-                constrained_prob[keep] = prob[keep]
+                constrained_prob[keep] = prob[keep] * ratio
                 return constrained_prob

     def ln_prob(self, sample, axis=None):
@@ -434,6 +457,7 @@ class PriorDict(dict):

         ln_prob = np.sum([self[key].ln_prob(sample[key])
                           for key in sample], axis=axis)
+        ratio = self.normalize_constraint_factor(tuple(sample.keys()))
         if np.all(np.isinf(ln_prob)):
             return ln_prob
         else:
@@ -445,7 +469,7 @@ class PriorDict(dict):
             else:
                 constrained_ln_prob = -np.inf * np.ones_like(ln_prob)
                 keep = np.array(self.evaluate_constraints(sample), dtype=bool)
-                constrained_ln_prob[keep] = ln_prob[keep]
+                constrained_ln_prob[keep] = ln_prob[keep] + np.log(ratio)
                 return constrained_ln_prob

     def rescale(self, keys, theta):
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index d11334ee6d8174d31c4646bd5ec18cec8d1a3d99..30e7ac6ac81aa8a704b1fc39a2a3b8520176fb77 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -1,17 +1,19 @@
 from __future__ import absolute_import

 import datetime
+import dill
 import os
 import sys
 import pickle
 import signal
+import time

 import tqdm
 import matplotlib.pyplot as plt
 import numpy as np
 from pandas import DataFrame

-from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
+from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect, safe_file_dump
 from .base_sampler import Sampler, NestedSampler
 from numpy import linalg

@@ -70,14 +72,14 @@ class Dynesty(NestedSampler):
     check_point_plot: bool,
         If true, generate a trace plot along with the check-point
     check_point_delta_t: float (600)
-        The approximate checkpoint period (in seconds). Should the run be
-        interrupted, it can be resumed from the last checkpoint. Set to
-        `None` to turn-off check pointing
+        The minimum checkpoint period (in seconds). Should the run be
+        interrupted, it can be resumed from the last checkpoint.
     n_check_point: int, optional (None)
-        The number of steps to take before check pointing (override
-        check_point_delta_t).
+        The number of steps to take before checking whether to checkpoint.
     resume: bool
         If true, resume run from checkpoint (if available)
+    exit_code: int
+        The code on which the sampler exits if it hasn't finished sampling
     """
     default_kwargs = dict(bound='multi', sample='rwalk', verbose=True,
                           periodic=None, reflective=None,
@@ -98,7 +100,7 @@ class Dynesty(NestedSampler):
     def __init__(self, likelihood, priors, outdir='outdir', label='label',
                  use_ratio=False, plot=False, skip_import_verification=False,
                  check_point=True, check_point_plot=True, n_check_point=None,
-                 check_point_delta_t=600, resume=True, **kwargs):
+                 check_point_delta_t=600, resume=True, exit_code=130, **kwargs):
         super(Dynesty, self).__init__(likelihood=likelihood, priors=priors,
                                       outdir=outdir, label=label, use_ratio=use_ratio, plot=plot,
                                       skip_import_verification=skip_import_verification,
@@ -110,19 +112,16 @@ class Dynesty(NestedSampler):
         self._periodic = list()
         self._reflective = list()
         self._apply_dynesty_boundaries()
-        if self.n_check_point is None:
-            # If the log_likelihood_eval_time is not calculable then
-            # check_point is set to False.
-            if np.isnan(self._log_likelihood_eval_time):
-                self.check_point = False
-            n_check_point_raw = (check_point_delta_t / self._log_likelihood_eval_time)
-            n_check_point_rnd = int(float("{:1.0g}".format(n_check_point_raw)))
-            self.n_check_point = n_check_point_rnd
-            logger.info("Checkpoint every n_check_point = {}".format(self.n_check_point))
+        if self.n_check_point is None:
+            self.n_check_point = 1000
+        self.check_point_delta_t = check_point_delta_t
+        logger.info("Checkpoint every check_point_delta_t = {}s"
+                    .format(check_point_delta_t))

         self.resume_file = '{}/{}_resume.pickle'.format(self.outdir, self.label)
         self.sampling_time = datetime.timedelta()
+        self.exit_code = exit_code

         signal.signal(signal.SIGTERM, self.write_current_state_and_exit)
         signal.signal(signal.SIGINT, self.write_current_state_and_exit)

@@ -131,7 +130,8 @@ class Dynesty(NestedSampler):
     def __getstate__(self):
         """ For pickle: remove external_sampler, which can be an unpicklable "module" """
         state = self.__dict__.copy()
-        del state['external_sampler']
+        if "external_sampler" in state:
+            del state['external_sampler']
         return state

     @property
@@ -319,9 +319,9 @@ class Dynesty(NestedSampler):
     def _run_external_sampler_with_checkpointing(self):
         logger.debug("Running sampler with checkpointing")
         if self.resume:
-            resume = self.read_saved_state(continuing=True)
-            if resume:
-                logger.info('Resuming from previous run.')
+            resume_file_loaded = self.read_saved_state(continuing=True)
+            if resume_file_loaded:
+                logger.info('Resume file successfully loaded.')

         old_ncall = self.sampler.ncall
         sampler_kwargs = self.sampler_function_kwargs.copy()
@@ -334,8 +334,15 @@ class Dynesty(NestedSampler):
                 break
             old_ncall = self.sampler.ncall

-            self.sampler._remove_live_points()
-            self.write_current_state()
+            if os.path.isfile(self.resume_file):
+                last_checkpoint_s = time.time() - os.path.getmtime(self.resume_file)
+            else:
+                last_checkpoint_s = (datetime.datetime.now() - self.start_time).total_seconds()
+            if last_checkpoint_s > self.check_point_delta_t:
+                self.write_current_state()
+                self.plot_current_state()
+                if self.sampler.added_live:
+                    self.sampler._remove_live_points()
         sampler_kwargs['add_live'] = True
         self._run_nested_wrapper(sampler_kwargs)

@@ -348,152 +355,147 @@

     def read_saved_state(self, continuing=False):
         """
-        Read a saved state of the sampler to disk.
+        Read a pickled saved state of the sampler from disk.

-        The required information to reconstruct the state of the run is read
-        from a pickle file.
-        This currently adds the whole chain to the sampler.
-        We then remove the old checkpoint and write all unnecessary items back
-        to disk.
-        FIXME: Load only the necessary quantities, rather than read/write?
+        If the live points are present and the run is continuing
+        they are removed.
+        The random state must be reset, as this isn't saved by the pickle.
+        `nqueue` is set to a negative number to trigger the queue to be
+        refilled before the first iteration.
+        The previous run time and start time are restored from the pickle
+        onto `self`.

         Parameters
         ----------
-        sampler: `dynesty.NestedSampler`
-            NestedSampler instance to reconstruct from the saved state.
         continuing: bool
             Whether the run is continuing or terminating, if True, the
             loaded state is mostly written back to disk.
         """
-
+        from ... import __version__ as bilby_version
+        from dynesty import __version__ as dynesty_version
+        versions = dict(bilby=bilby_version, dynesty=dynesty_version)
         if os.path.isfile(self.resume_file):
             logger.info("Reading resume file {}".format(self.resume_file))
-            try:
-                with open(self.resume_file, 'rb') as file:
-                    saved = pickle.load(file)
-                logger.info(
-                    "Succesfuly read resume file {}".format(self.resume_file))
-            except EOFError as e:
-                logger.warning("Resume file reading failed with error {}".format(e))
-                return False
-
-            self.sampler.saved_u = list(saved['unit_cube_samples'])
-            self.sampler.saved_v = list(saved['physical_samples'])
-            self.sampler.saved_logl = list(saved['sample_likelihoods'])
-            self.sampler.saved_logvol = list(saved['sample_log_volume'])
-            self.sampler.saved_logwt = list(saved['sample_log_weights'])
-            self.sampler.saved_logz = list(saved['cumulative_log_evidence'])
-            self.sampler.saved_logzvar = list(saved['cumulative_log_evidence_error'])
-            self.sampler.saved_id = list(saved['id'])
-            self.sampler.saved_it = list(saved['it'])
-            self.sampler.saved_nc = list(saved['nc'])
-            self.sampler.saved_boundidx = list(saved['boundidx'])
-            self.sampler.saved_bounditer = list(saved['bounditer'])
-            self.sampler.saved_scale = list(saved['scale'])
-            self.sampler.saved_h = list(saved['cumulative_information'])
-            self.sampler.ncall = saved['ncall']
-            self.sampler.live_logl = list(saved['live_logl'])
-            self.sampler.it = saved['iteration'] + 1
-            self.sampler.live_u = saved['live_u']
-            self.sampler.live_v = saved['live_v']
-            self.sampler.nlive = saved['nlive']
-            self.sampler.live_bound = saved['live_bound']
-            self.sampler.live_it = saved['live_it']
-            self.sampler.added_live = saved['added_live']
-            self.sampler.bound = saved['bound']
-            self.sampler.nbound = saved['nbound']
-            self.sampling_time += datetime.timedelta(seconds=saved['sampling_time'])
+            with open(self.resume_file, 'rb') as file:
+                sampler = dill.load(file)
+
+            if not hasattr(sampler, "versions"):
+                logger.warning(
+                    "The resume file {} is corrupted or the version of "
+                    "bilby has changed between runs. This resume file will "
+                    "be ignored."
+                    .format(self.resume_file)
+                )
+                return False
+            version_warning = (
+                "The {code} version has changed between runs. "
+                "This may cause unpredictable behaviour and/or failure. "
+                "Old version = {old}, new version = {new}."
+
+            )
+            for code in versions:
+                if not versions[code] == sampler.versions.get(code, None):
+                    logger.warning(version_warning.format(
+                        code=code,
+                        old=sampler.versions.get(code, "None"),
+                        new=versions[code]
+                    ))
+            del sampler.versions
+            self.sampler = sampler
+            if self.sampler.added_live and continuing:
+                self.sampler._remove_live_points()
+            self.sampler.nqueue = -1
+            self.sampler.rstate = np.random
+            self.start_time = self.sampler.kwargs.pop("start_time")
+            self.sampling_time = self.sampler.kwargs.pop("sampling_time")
             return True
-        else:
-            logger.debug(
-                "No resume file {}".format(self.resume_file))
+        logger.info(
+            "Resume file {} does not exist.".format(self.resume_file))
         return False

     def write_current_state_and_exit(self, signum=None, frame=None):
-        logger.warning("Run terminated with signal {}".format(signum))
-        self.write_current_state(plot=False)
-        sys.exit(130)
+        if signum == 14:
+            logger.info(
+                "Run interrupted by alarm signal {}: checkpoint and exit on {}"
+                .format(signum, self.exit_code))
+        else:
+            logger.info(
+                "Run interrupted by signal {}: checkpoint and exit on {}"
+                .format(signum, self.exit_code))
+        self.write_current_state()
+        os._exit(self.exit_code)

-    def write_current_state(self, plot=True):
+    def write_current_state(self):
         """
         Write the current state of the sampler to disk.

-        The required information to reconstruct the state of the run are written
-        to an hdf5 file.
-        All but the most recent removed live point in the chain are removed from
-        the sampler to reduce memory usage.
-        This means it is necessary to not append the first live point to the
-        file if updating a previous checkpoint.
+        The sampler is pickle dumped using `dill`.
+        The sampling time is also stored to get the full CPU time for the run.

-        Parameters
-        ----------
-        sampler: `dynesty.NestedSampler`
-            NestedSampler to write to disk.
+        The check of whether the sampler is picklable is to catch an error
+        when using pytest. Hopefully, this message won't be triggered during
+        normal running.
         """
-
-        check_directory_exists_and_if_not_mkdir(self.outdir)
-        print("")
-        logger.info("Writing checkpoint file {}".format(self.resume_file))
+        from ... import __version__ as bilby_version
+        from dynesty import __version__ as dynesty_version
+        check_directory_exists_and_if_not_mkdir(self.outdir)

         end_time = datetime.datetime.now()
         if hasattr(self, 'start_time'):
             self.sampling_time += end_time - self.start_time
             self.start_time = end_time
-
-        current_state = dict(
-            unit_cube_samples=self.sampler.saved_u,
-            physical_samples=self.sampler.saved_v,
-            sample_likelihoods=self.sampler.saved_logl,
-            sample_log_volume=self.sampler.saved_logvol,
-            sample_log_weights=self.sampler.saved_logwt,
-            cumulative_log_evidence=self.sampler.saved_logz,
-            cumulative_log_evidence_error=self.sampler.saved_logzvar,
-            cumulative_information=self.sampler.saved_h,
-            id=self.sampler.saved_id,
-            it=self.sampler.saved_it,
-            nc=self.sampler.saved_nc,
-            bound=self.sampler.bound,
-            nbound=self.sampler.nbound,
-            boundidx=self.sampler.saved_boundidx,
-            bounditer=self.sampler.saved_bounditer,
-            scale=self.sampler.saved_scale,
-            sampling_time=self.sampling_time.total_seconds()
-        )
-
-        current_state.update(
-            ncall=self.sampler.ncall, live_logl=self.sampler.live_logl,
-            iteration=self.sampler.it - 1, live_u=self.sampler.live_u,
-            live_v=self.sampler.live_v, nlive=self.sampler.nlive,
-            live_bound=self.sampler.live_bound, live_it=self.sampler.live_it,
-            added_live=self.sampler.added_live
+        self.sampler.kwargs["sampling_time"] = self.sampling_time
+        self.sampler.kwargs["start_time"] = self.start_time
+        self.sampler.versions = dict(
+            bilby=bilby_version, dynesty=dynesty_version
         )
+        if dill.pickles(self.sampler):
+            safe_file_dump(self.sampler, self.resume_file, dill)
+            logger.info("Written checkpoint file {}".format(self.resume_file))
+        else:
+            logger.warning(
+                "Cannot write pickle resume file! "
+                "Job will not resume if interrupted."
+            )

-        try:
-            weights = np.exp(current_state['sample_log_weights'] -
-                             current_state['cumulative_log_evidence'][-1])
-            from dynesty.utils import resample_equal
-
-            current_state['posterior'] = resample_equal(
-                np.array(current_state['physical_samples']), weights)
-            current_state['search_parameter_keys'] = self.search_parameter_keys
-        except ValueError:
-            logger.debug("Unable to create posterior")
-
-        with open(self.resume_file, 'wb') as file:
-            pickle.dump(current_state, file)
-
-        if plot and self.check_point_plot:
+    def plot_current_state(self):
+        if self.check_point_plot:
             import dynesty.plotting as dyplot
             labels = [label.replace('_', ' ') for label in self.search_parameter_keys]
-            filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
             try:
+                filename = "{}/{}_checkpoint_trace.png".format(self.outdir, self.label)
                 fig = dyplot.traceplot(self.sampler.results, labels=labels)[0]
                 fig.tight_layout()
                 fig.savefig(filename)
-                plt.close('all')
             except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
                 logger.warning(e)
                 logger.warning('Failed to create dynesty state plot at checkpoint')
+            finally:
+                plt.close("all")
+            try:
+                filename = "{}/{}_checkpoint_run.png".format(self.outdir, self.label)
+                fig, axs = dyplot.runplot(self.sampler.results)
+                fig.tight_layout()
+                plt.savefig(filename)
+            except (RuntimeError, np.linalg.linalg.LinAlgError, ValueError) as e:
+                logger.warning(e)
+                logger.warning('Failed to create dynesty run plot at checkpoint')
+            finally:
+                plt.close('all')
+            try:
+                filename = "{}/{}_checkpoint_stats.png".format(self.outdir, self.label)
+                fig, axs = plt.subplots(nrows=3, sharex=True)
+                for ax, name in zip(axs, ["boundidx", "nc", "scale"]):
+                    ax.plot(getattr(self.sampler, f"saved_{name}"), color="C0")
+                    ax.set_ylabel(name)
+                axs[-1].set_xlabel("iteration")
+                fig.tight_layout()
+                plt.savefig(filename)
+            except (RuntimeError, ValueError) as e:
+                logger.warning(e)
+                logger.warning('Failed to create dynesty stats plot at checkpoint')
+            finally:
+                plt.close('all')

     def generate_trace_plots(self, dynesty_results):
         check_directory_exists_and_if_not_mkdir(self.outdir)
@@ -573,7 +575,6 @@ def sample_rwalk_bilby(args):
     logl_list = [loglikelihood(v_list[-1])]

     max_walk_warning = True
-    drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan

     while len(u_list) < nact * act:
         # Propose a direction on the unit n-sphere.
diff --git a/bilby/core/sampler/kombine.py b/bilby/core/sampler/kombine.py
index cd37070511ecf805e563d3d43a71b1a13dbb5e01..48e85342a8fd319453fb47f305970b920e724791 100644
--- a/bilby/core/sampler/kombine.py
+++ b/bilby/core/sampler/kombine.py
@@ -160,8 +160,8 @@ class Kombine(Emcee):
         self.result.nburn = self.nburn
         if self.result.nburn > self.nsteps:
             raise SamplerError(
-                "The run has finished, but the chain is not burned in: "
-                "`nburn < nsteps`. Try increasing the number of steps.")
+                "The run has finished, but the chain is not burned in: `nburn < nsteps` ({} < {}). Try increasing the "
+                "number of steps.".format(self.result.nburn, self.nsteps))
         tmp_chain = self.sampler.chain[self.nburn:, :, :].copy()
         self.result.samples = tmp_chain.reshape((-1, self.ndim))
         blobs = np.array(self.sampler.blobs)
diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index 6d41a957710d54a9018fc0ea617a01eff870ab8e..71b0b50285d917579c1a66e3b66ded4edcdb1a82 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -1101,6 +1101,25 @@ def reflect(u):
     return u


+def safe_file_dump(data, filename, module):
+    """ Safely dump data to a .pickle file
+
+    Parameters
+    ----------
+    data:
+        data to dump
+    filename: str
+        The file to dump to
+    module: pickle, dill
+        The python module to use
+    """
+
+    temp_filename = filename + ".temp"
+    with open(temp_filename, "wb") as file:
+        module.dump(data, file)
+    os.rename(temp_filename, filename)
+
+
 def latex_plot_format(func):
     """ Wrap a plotting function to set rcParams so that text renders nicely
diff --git a/docs/prior.txt b/docs/prior.txt
index 8f89a0ad06317cb1852599da9eaa8fb01bebe2f3..8e5f2944d1d315edc23cd5af2cd38394dcb828a6 100644
--- a/docs/prior.txt
+++ b/docs/prior.txt
@@ -178,8 +178,9 @@ First thing: define a function which generates z=x-y from x and y.
         -------
         dict: Dictionary with constraint parameter 'z' added.
""" - parameters['z'] = parameters['x'] - parameters['y'] - return parameters + converted_parameters = parameters.copy() + converted_parameters['z'] = parameters['x'] - parameters['y'] + return converted_parameters Create our prior: diff --git a/setup.cfg b/setup.cfg index f303e1bace88e331ca49d46c6bccc3249397df30..ad60888c125af2650afb6c2b495a3cbf720123e0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] exclude = .git,docs,build,dist,test,*__init__.py max-line-length = 120 -ignore = E129 W504 W605 E203 +ignore = E129 W503 W504 W605 E203 [tool:pytest] addopts = diff --git a/test/prior_test.py b/test/prior_test.py index 832956928bc2c44217395afd287cc564868d4325..a3c5a312eddb1c2438295b8e29f2fb0c9f16a3c7 100644 --- a/test/prior_test.py +++ b/test/prior_test.py @@ -907,6 +907,26 @@ class TestPriorDict(unittest.TestCase): self.assertFalse(self.prior_set_from_dict.test_redundancy(key=key)) +class TestConstraintPriorNormalisation(unittest.TestCase): + def setUp(self): + self.priors = dict(mass_1=bilby.core.prior.Uniform(name='mass_1', minimum=5, maximum=10, unit='$M_{\odot}$', + boundary=None), + mass_2=bilby.core.prior.Uniform(name='mass_2', minimum=5, maximum=10, unit='$M_{\odot}$', + boundary=None), + mass_ratio=bilby.core.prior.Constraint(name='mass_ratio', minimum=0, maximum=1)) + self.priors = bilby.core.prior.PriorDict(self.priors) + + def test_prob_integrate_to_one(self): + keys = ['mass_1', 'mass_2', 'mass_ratio'] + n = 5000 + samples = self.priors.sample_subset(keys=keys, size=n) + prob = self.priors.prob(samples, axis=0) + dm1 = self.priors['mass_1'].maximum - self.priors['mass_1'].minimum + dm2 = self.priors['mass_2'].maximum - self.priors['mass_2'].minimum + integral = np.sum(prob * (dm1 * dm2)) / len(samples['mass_1']) + self.assertAlmostEqual(1, integral, 5) + + class TestLoadPrior(unittest.TestCase): def test_load_prior_with_float(self): filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), diff --git a/test/sampler_test.py b/test/sampler_test.py index a157dddb05b5987bcb69725234cbf8dc8f406477..ddd54b72bd1a749f2c0fe2985a4db18826d9e120 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -549,7 +549,7 @@ class TestRunningSamplers(unittest.TestCase): def test_run_kombine(self): _ = bilby.run_sampler( likelihood=self.likelihood, priors=self.priors, sampler='kombine', - iterations=2500, nwalkers=100, save=False) + iterations=1000, nwalkers=100, save=False, autoburnin=True) def test_run_nestle(self): _ = bilby.run_sampler(