diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py
index dde8d920cd52a03b7d3e647f29c5333378af20db..eb52cf2891feeddf4ff9b3e32a47dfa67809f45a 100644
--- a/bilby/core/sampler/__init__.py
+++ b/bilby/core/sampler/__init__.py
@@ -22,15 +22,29 @@ from .pymultinest import Pymultinest
 from .ultranest import Ultranest
 from .fake_sampler import FakeSampler
 from .dnest4 import DNest4
+from .zeus import Zeus
 from bilby.bilby_mcmc import Bilby_MCMC
 from . import proposal
 
 IMPLEMENTED_SAMPLERS = {
-    'bilby_mcmc': Bilby_MCMC, 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty,
-    'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai,
-    'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler,
-    'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord,
-    'ultranest': Ultranest, 'fake_sampler': FakeSampler}
+    "bilby_mcmc": Bilby_MCMC,
+    "cpnest": Cpnest,
+    "dnest4": DNest4,
+    "dynamic_dynesty": DynamicDynesty,
+    "dynesty": Dynesty,
+    "emcee": Emcee,
+    "kombine": Kombine,
+    "nessai": Nessai,
+    "nestle": Nestle,
+    "ptemcee": Ptemcee,
+    "ptmcmcsampler": PTMCMCSampler,
+    "pymc3": Pymc3,
+    "pymultinest": Pymultinest,
+    "pypolychord": PyPolyChord,
+    "ultranest": Ultranest,
+    "zeus": Zeus,
+    "fake_sampler": FakeSampler,
+}
 
 if command_line_args.sampler_help:
     sampler = command_line_args.sampler_help
@@ -40,20 +54,36 @@ if command_line_args.sampler_help:
         print(sampler_class.__doc__)
     else:
         if sampler == "None":
-            print('For help with a specific sampler, call sampler-help with '
-                  'the name of the sampler')
+            print(
+                "For help with a specific sampler, call sampler-help with "
+                "the name of the sampler"
+            )
         else:
-            print('Requested sampler {} not implemented'.format(sampler))
-            print('Available samplers = {}'.format(IMPLEMENTED_SAMPLERS))
+            print("Requested sampler {} not implemented".format(sampler))
+            print("Available samplers = {}".format(IMPLEMENTED_SAMPLERS))
     sys.exit()
 
 
-def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
-                sampler='dynesty', use_ratio=None, injection_parameters=None,
-                conversion_function=None, plot=False, default_priors_file=None,
-                clean=None, meta_data=None, save=True, gzip=False,
-                result_class=None, npool=1, **kwargs):
+def run_sampler(
+    likelihood,
+    priors=None,
+    label="label",
+    outdir="outdir",
+    sampler="dynesty",
+    use_ratio=None,
+    injection_parameters=None,
+    conversion_function=None,
+    plot=False,
+    default_priors_file=None,
+    clean=None,
+    meta_data=None,
+    save=True,
+    gzip=False,
+    result_class=None,
+    npool=1,
+    **kwargs
+):
     """
     The primary interface to easy parameter estimation
 
@@ -116,13 +146,13 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
 
     """
     logger.info(
-        "Running for label '{}', output will be saved to '{}'".format(
-            label, outdir))
+        "Running for label '{}', output will be saved to '{}'".format(label, outdir)
+    )
 
     if clean:
         command_line_args.clean = clean
     if command_line_args.clean:
-        kwargs['resume'] = False
+        kwargs["resume"] = False
 
     from . import IMPLEMENTED_SAMPLERS
@@ -143,11 +173,12 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
     # Generate the meta-data if not given and append the likelihood meta_data
     if meta_data is None:
         meta_data = dict()
-    meta_data['likelihood'] = likelihood.meta_data
+    meta_data["likelihood"] = likelihood.meta_data
     meta_data["loaded_modules"] = loaded_modules_dict()
 
     if command_line_args.bilby_zero_likelihood_mode:
         from bilby.core.likelihood import ZeroLikelihood
+
         likelihood = ZeroLikelihood(likelihood)
 
     if isinstance(sampler, Sampler):
@@ -156,24 +187,39 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
         if sampler.lower() in IMPLEMENTED_SAMPLERS:
             sampler_class = IMPLEMENTED_SAMPLERS[sampler.lower()]
             sampler = sampler_class(
-                likelihood, priors=priors, outdir=outdir, label=label,
-                injection_parameters=injection_parameters, meta_data=meta_data,
-                use_ratio=use_ratio, plot=plot, result_class=result_class,
-                npool=npool, **kwargs)
+                likelihood,
+                priors=priors,
+                outdir=outdir,
+                label=label,
+                injection_parameters=injection_parameters,
+                meta_data=meta_data,
+                use_ratio=use_ratio,
+                plot=plot,
+                result_class=result_class,
+                npool=npool,
+                **kwargs
+            )
         else:
             print(IMPLEMENTED_SAMPLERS)
-            raise ValueError(
-                "Sampler {} not yet implemented".format(sampler))
+            raise ValueError("Sampler {} not yet implemented".format(sampler))
     elif inspect.isclass(sampler):
         sampler = sampler.__init__(
-            likelihood, priors=priors,
-            outdir=outdir, label=label, use_ratio=use_ratio, plot=plot,
-            injection_parameters=injection_parameters, meta_data=meta_data,
-            npool=npool, **kwargs)
+            likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            injection_parameters=injection_parameters,
+            meta_data=meta_data,
+            npool=npool,
+            **kwargs
+        )
     else:
         raise ValueError(
             "Provided sampler should be a Sampler object or name of a known "
-            "sampler: {}.".format(', '.join(IMPLEMENTED_SAMPLERS.keys())))
+            "sampler: {}.".format(", ".join(IMPLEMENTED_SAMPLERS.keys()))
+        )
 
     if sampler.cached_result:
         logger.warning("Using cached result")
@@ -191,31 +237,31 @@ def run_sampler(likelihood, priors=None, label='label', outdir='outdir',
             result.sampling_time = end_time - start_time
         elif isinstance(result.sampling_time, float):
             result.sampling_time = datetime.timedelta(result.sampling_time)
-        logger.info('Sampling time: {}'.format(result.sampling_time))
+        logger.info("Sampling time: {}".format(result.sampling_time))
         # Convert sampling time into seconds
         result.sampling_time = result.sampling_time.total_seconds()
 
     if sampler.use_ratio:
         result.log_noise_evidence = likelihood.noise_log_likelihood()
         result.log_bayes_factor = result.log_evidence
-        result.log_evidence = \
-            result.log_bayes_factor + result.log_noise_evidence
+        result.log_evidence = result.log_bayes_factor + result.log_noise_evidence
     else:
         result.log_noise_evidence = likelihood.noise_log_likelihood()
-        result.log_bayes_factor = \
-            result.log_evidence - result.log_noise_evidence
+        result.log_bayes_factor = result.log_evidence - result.log_noise_evidence
 
     # Initial save of the sampler in case of failure in post-processing
     if save:
         result.save_to_file(extension=save, gzip=gzip)
 
     if None not in [result.injection_parameters, conversion_function]:
-        result.injection_parameters = conversion_function(
-            result.injection_parameters)
+        result.injection_parameters = conversion_function(result.injection_parameters)
 
-    result.samples_to_posterior(likelihood=likelihood, priors=result.priors,
-                                conversion_function=conversion_function,
-                                npool=npool)
+    result.samples_to_posterior(
+        likelihood=likelihood,
+        priors=result.priors,
+        conversion_function=conversion_function,
+        npool=npool,
+    )
 
     if save:
         # The overwrite here ensures we overwrite the initially stored data
@@ -232,5 +278,7 @@ def _check_marginalized_parameters_not_sampled(likelihood, priors):
         if key in priors:
             if not isinstance(priors[key], (float, DeltaFunction)):
                 raise SamplingMarginalisedParameterError(
-                    "Likelihood is {} marginalized but you are trying to sample in {}. "
-                    .format(key, key))
+                    "Likelihood is {} marginalized but you are trying to sample in {}. ".format(
+                        key, key
+                    )
+                )
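
With `"zeus"` registered in `IMPLEMENTED_SAMPLERS`, the new backend is reachable through the standard string interface. A minimal sketch of how a reviewer might exercise it (not part of the diff; the Gaussian likelihood, priors, and settings below are illustrative only):

```python
import numpy as np
import bilby

# Toy problem: infer the mean of Gaussian-distributed data.
x = np.arange(100)
y = np.random.normal(3, 4, 100)
likelihood = bilby.core.likelihood.GaussianLikelihood(
    x=x, y=y, func=lambda x, mu: mu + 0 * x, sigma=4
)
priors = dict(mu=bilby.core.prior.Uniform(-10, 10, "mu"))

# "zeus" now resolves via IMPLEMENTED_SAMPLERS; nwalkers and nsteps are
# translated to zeus.EnsembleSampler's nwalkers/iterations kwargs by
# the _translate_kwargs method added below.
result = bilby.run_sampler(
    likelihood=likelihood,
    priors=priors,
    sampler="zeus",
    nwalkers=100,
    nsteps=200,
    label="zeus_demo",
    outdir="outdir",
)
```
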
diff --git a/bilby/core/sampler/emcee.py b/bilby/core/sampler/emcee.py
index 48516aa94e4d0770cd14adcbe31f73ed0d63753c..1f9ba786ccd1230db6f073defa90c1ca78bf9630 100644
--- a/bilby/core/sampler/emcee.py
+++ b/bilby/core/sampler/emcee.py
@@ -23,7 +23,7 @@ class Emcee(MCMCSampler):
 
     Parameters
     ==========
-    nwalkers: int, (100)
+    nwalkers: int, (500)
         The number of walkers
     nsteps: int, (100)
         The number of steps
@@ -273,7 +273,7 @@ class Emcee(MCMCSampler):
 
     @property
     def sampler(self):
-        """ Returns the ptemcee sampler object
+        """ Returns the emcee sampler object
 
         If, already initialized, returns the stored _sampler value. Otherwise,
         first checks if there is a pickle file from which to load. If there is
diff --git a/bilby/core/sampler/zeus.py b/bilby/core/sampler/zeus.py
new file mode 100644
index 0000000000000000000000000000000000000000..78c3529ea00c1c9ee518a8698e2f70bc29fe194d
--- /dev/null
+++ b/bilby/core/sampler/zeus.py
@@ -0,0 +1,392 @@
+import os
+import signal
+import shutil
+import sys
+from collections import namedtuple
+from shutil import copyfile
+
+import numpy as np
+from pandas import DataFrame
+
+from ..utils import logger, check_directory_exists_and_if_not_mkdir
+from .base_sampler import MCMCSampler, SamplerError
+
+
+class Zeus(MCMCSampler):
+    """bilby wrapper for Zeus (https://zeus-mcmc.readthedocs.io/)
+
+    All positional and keyword arguments (i.e., the args and kwargs) passed to
+    `run_sampler` will be propagated to `zeus.EnsembleSampler`, see
+    documentation for that class for further help. Under Other Parameters, we
+    list commonly used kwargs and the bilby defaults.
+
+    Parameters
+    ==========
+    nwalkers: int, (500)
+        The number of walkers
+    nsteps: int, (100)
+        The number of steps
+    nburn: int (None)
+        If given, the fixed number of steps to discard as burn-in. These will
+        be discarded from the total number of steps set by `nsteps` and
+        therefore the value must be less than `nsteps`. Else, nburn is
+        estimated from the autocorrelation time
+    burn_in_fraction: float, (0.25)
+        The fraction of steps to discard as burn-in in the event that the
+        autocorrelation time cannot be calculated
+    burn_in_act: float
+        The number of autocorrelation times to discard as burn-in
+
+    """
+
+    default_kwargs = dict(
+        nwalkers=500,
+        args=[],
+        kwargs={},
+        pool=None,
+        log_prob0=None,
+        start=None,
+        blobs0=None,
+        iterations=100,
+        thin=1,
+    )
+
+    def __init__(
+        self,
+        likelihood,
+        priors,
+        outdir="outdir",
+        label="label",
+        use_ratio=False,
+        plot=False,
+        skip_import_verification=False,
+        pos0=None,
+        nburn=None,
+        burn_in_fraction=0.25,
+        resume=True,
+        burn_in_act=3,
+        **kwargs
+    ):
+        import zeus
+
+        self.zeus = zeus
+
+        super(Zeus, self).__init__(
+            likelihood=likelihood,
+            priors=priors,
+            outdir=outdir,
+            label=label,
+            use_ratio=use_ratio,
+            plot=plot,
+            skip_import_verification=skip_import_verification,
+            **kwargs
+        )
+        self.resume = resume
+        self.pos0 = pos0
+        self.nburn = nburn
+        self.burn_in_fraction = burn_in_fraction
+        self.burn_in_act = burn_in_act
+
+        signal.signal(signal.SIGTERM, self.checkpoint_and_exit)
+        signal.signal(signal.SIGINT, self.checkpoint_and_exit)
+
+    def _translate_kwargs(self, kwargs):
+        if "nwalkers" not in kwargs:
+            for equiv in self.nwalkers_equiv_kwargs:
+                if equiv in kwargs:
+                    kwargs["nwalkers"] = kwargs.pop(equiv)
+        if "iterations" not in kwargs:
+            if "nsteps" in kwargs:
+                kwargs["iterations"] = kwargs.pop("nsteps")
+
+        # check if using emcee-style arguments
+        if "start" not in kwargs:
+            if "rstate0" in kwargs:
+                kwargs["start"] = kwargs.pop("rstate0")
+        if "log_prob0" not in kwargs:
+            if "lnprob0" in kwargs:
+                kwargs["log_prob0"] = kwargs.pop("lnprob0")
+
+        if "threads" in kwargs:
+            if kwargs["threads"] != 1:
+                logger.warning(
+                    "The 'threads' argument cannot be used for "
+                    "parallelisation. This run will proceed "
+                    "without parallelisation, but consider the use "
+                    "of an appropriate Pool object passed to the "
+                    "'pool' keyword."
+                )
+            kwargs["threads"] = 1
+
+    @property
+    def sampler_function_kwargs(self):
+        keys = ["log_prob0", "start", "blobs0", "iterations", "thin", "progress"]
+
+        function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs}
+
+        return function_kwargs
+
+    @property
+    def sampler_init_kwargs(self):
+        init_kwargs = {
+            key: value
+            for key, value in self.kwargs.items()
+            if key not in self.sampler_function_kwargs
+        }
+
+        init_kwargs["logprob_fn"] = self.lnpostfn
+        init_kwargs["ndim"] = self.ndim
+
+        return init_kwargs
+
+    def lnpostfn(self, theta):
+        log_prior = self.log_prior(theta)
+        if np.isinf(log_prior):
+            return -np.inf, [np.nan, np.nan]
+        else:
+            log_likelihood = self.log_likelihood(theta)
+            return log_likelihood + log_prior, [log_likelihood, log_prior]
+
+    @property
+    def nburn(self):
+        if type(self.__nburn) in [float, int]:
+            return int(self.__nburn)
+        elif self.result.max_autocorrelation_time is None:
+            return int(self.burn_in_fraction * self.nsteps)
+        else:
+            return int(self.burn_in_act * self.result.max_autocorrelation_time)
+
+    @nburn.setter
+    def nburn(self, nburn):
+        if isinstance(nburn, (float, int)):
+            if nburn > self.kwargs["iterations"] - 1:
+                raise ValueError(
+                    "Number of burn-in samples must be smaller "
+                    "than the total number of iterations"
+                )
+
+        self.__nburn = nburn
+
+    @property
+    def nwalkers(self):
+        return self.kwargs["nwalkers"]
+
+    @property
+    def nsteps(self):
+        return self.kwargs["iterations"]
+
+    @nsteps.setter
+    def nsteps(self, nsteps):
+        self.kwargs["iterations"] = nsteps
+
+    @property
+    def stored_chain(self):
+        """Read the stored chain data in from disk"""
+        return np.genfromtxt(self.checkpoint_info.chain_file, names=True)
+
+    @property
+    def stored_samples(self):
+        """Returns the samples stored on disk"""
+        return self.stored_chain[self.search_parameter_keys]
+
+    @property
+    def stored_loglike(self):
+        """Returns the log-likelihood stored on disk"""
+        return self.stored_chain["log_l"]
+
+    @property
+    def stored_logprior(self):
+        """Returns the log-prior stored on disk"""
+        return self.stored_chain["log_p"]
+
+    def _init_chain_file(self):
+        with open(self.checkpoint_info.chain_file, "w+") as ff:
+            ff.write(
+                "walker\t{}\tlog_l\tlog_p\n".format(
+                    "\t".join(self.search_parameter_keys)
+                )
+            )
+
+    @property
+    def checkpoint_info(self):
+        """Defines various things related to checkpointing and storing data
+
+        Returns
+        =======
+        checkpoint_info: named_tuple
+            An object with attributes `sampler_file`, `chain_file`, and
+            `chain_template`. The first two give paths to where the sampler and
+            chain data are stored; the last is a formatted-str-template with
+            which to write the chain data to disk
+
+        """
+        out_dir = os.path.join(
+            self.outdir, "{}_{}".format(self.__class__.__name__.lower(), self.label)
+        )
+        check_directory_exists_and_if_not_mkdir(out_dir)
+
+        chain_file = os.path.join(out_dir, "chain.dat")
+        sampler_file = os.path.join(out_dir, "sampler.pickle")
+        chain_template = (
+            "{:d}" + "\t{:.9e}" * (len(self.search_parameter_keys) + 2) + "\n"
+        )
+
+        CheckpointInfo = namedtuple(
+            "CheckpointInfo", ["sampler_file", "chain_file", "chain_template"]
+        )
+
+        checkpoint_info = CheckpointInfo(
+            sampler_file=sampler_file,
+            chain_file=chain_file,
+            chain_template=chain_template,
+        )
+
+        return checkpoint_info
+
+    @property
+    def sampler_chain(self):
+        nsteps = self._previous_iterations
+        return self.sampler.chain[:, :nsteps, :]
+
+    def checkpoint(self):
+        """Writes a pickle file of the sampler to disk using dill"""
+        import dill
+
+        logger.info(
+            "Checkpointing sampler to file {}".format(self.checkpoint_info.sampler_file)
+        )
+        with open(self.checkpoint_info.sampler_file, "wb") as f:
+            # Overwrites the stored sampler chain with one that is truncated
+            # to only the completed steps
+            self.sampler._chain = self.sampler_chain
+            dill.dump(self._sampler, f)
+
+    def checkpoint_and_exit(self, signum, frame):
+        logger.info("Received signal {}".format(signum))
+        self.checkpoint()
+        sys.exit()
+
+    def _initialise_sampler(self):
+        self._sampler = self.zeus.EnsembleSampler(**self.sampler_init_kwargs)
+        self._init_chain_file()
+
+    @property
+    def sampler(self):
+        """Returns the Zeus sampler object
+
+        If already initialized, returns the stored _sampler value. Otherwise,
+        first checks if there is a pickle file from which to load. If there is
+        not, the sampler is initialized and the initial positions are set
+
+        """
+        if hasattr(self, "_sampler"):
+            pass
+        elif self.resume and os.path.isfile(self.checkpoint_info.sampler_file):
+            import dill
+
+            logger.info(
+                "Resuming run from checkpoint file {}".format(
+                    self.checkpoint_info.sampler_file
+                )
+            )
+            with open(self.checkpoint_info.sampler_file, "rb") as f:
+                self._sampler = dill.load(f)
+            self._set_pos0_for_resume()
+        else:
+            self._initialise_sampler()
+            self._set_pos0()
+        return self._sampler
+
+    def write_chains_to_file(self, sample):
+        chain_file = self.checkpoint_info.chain_file
+        temp_chain_file = chain_file + ".temp"
+        if os.path.isfile(chain_file):
+            copyfile(chain_file, temp_chain_file)
+
+        points = np.hstack([sample[0], np.array(sample[2])])
+
+        with open(temp_chain_file, "a") as ff:
+            for ii, point in enumerate(points):
+                ff.write(self.checkpoint_info.chain_template.format(ii, *point))
+        shutil.move(temp_chain_file, chain_file)
+
+    @property
+    def _previous_iterations(self):
+        """Returns the number of iterations that the sampler has saved
+
+        This is used when loading in a sampler from a pickle file to figure out
+        how much of the run has already been completed
+        """
+        try:
+            return len(self.sampler.get_blobs())
+        except AttributeError:
+            return 0
+
+    def _draw_pos0_from_prior(self):
+        return np.array(
+            [self.get_random_draw_from_prior() for _ in range(self.nwalkers)]
+        )
+
+    @property
+    def _pos0_shape(self):
+        return (self.nwalkers, self.ndim)
+
+    def _set_pos0(self):
+        if self.pos0 is not None:
+            logger.debug("Using given initial positions for walkers")
+            if isinstance(self.pos0, DataFrame):
+                self.pos0 = self.pos0[self.search_parameter_keys].values
+            elif type(self.pos0) in (list, np.ndarray):
+                self.pos0 = np.squeeze(self.pos0)
+
+            if self.pos0.shape != self._pos0_shape:
+                raise ValueError("Input pos0 should be of shape (nwalkers, ndim)")
+            logger.debug("Checking input pos0")
+            for draw in self.pos0:
+                self.check_draw(draw)
+        else:
+            logger.debug("Generating initial walker positions from prior")
+            self.pos0 = self._draw_pos0_from_prior()
+
+    def _set_pos0_for_resume(self):
+        self.pos0 = self.sampler.get_last_sample()
+
+    def run_sampler(self):
+        sampler_function_kwargs = self.sampler_function_kwargs
+        iterations = sampler_function_kwargs.pop("iterations")
+        iterations -= self._previous_iterations
+
+        sampler_function_kwargs["start"] = self.pos0
+
+        # main iteration loop
+        for sample in self.sampler.sample(
+            iterations=iterations, **sampler_function_kwargs
+        ):
+            self.write_chains_to_file(sample)
+        self.checkpoint()
+
+        self.result.sampler_output = np.nan
+        self.calculate_autocorrelation(self.sampler.chain.reshape((-1, self.ndim)))
+        self.print_nburn_logging_info()
+
+        self._generate_result()
+
+        self.result.samples = self.sampler.get_chain(flat=True, discard=self.nburn)
+        self.result.walkers = self.sampler.chain
+        return self.result
+
+    def _generate_result(self):
+        self.result.nburn = self.nburn
+        self.calc_likelihood_count()
+        if self.result.nburn > self.nsteps:
+            raise SamplerError(
+                "The run has finished, but the chain is not burned in: "
+                "`nburn > nsteps` ({} > {}). Try increasing the "
+                "number of steps.".format(self.result.nburn, self.nsteps)
+            )
+        blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2))
+        log_likelihoods, log_priors = blobs.T
+        self.result.log_likelihood_evaluations = log_likelihoods
+        self.result.log_prior_evaluations = log_priors
+        self.result.log_evidence = np.nan
+        self.result.log_evidence_err = np.nan
Try increasing the " + "number of steps.".format(self.result.nburn, self.nsteps) + ) + blobs = np.array(self.sampler.get_blobs(flat=True, discard=self.nburn)).reshape((-1, 2)) + log_likelihoods, log_priors = blobs.T + self.result.log_likelihood_evaluations = log_likelihoods + self.result.log_prior_evaluations = log_priors + self.result.log_evidence = np.nan + self.result.log_evidence_err = np.nan diff --git a/docs/samplers.txt b/docs/samplers.txt index 8e26550cb1c2d564d8c1c4ed3810890ed214cf98..fad1d62ba2953e46bce4e635c98269c0e8aeecb1 100644 --- a/docs/samplers.txt +++ b/docs/samplers.txt @@ -70,6 +70,7 @@ MCMC samplers - emcee :code:`bilby.core.sampler.emcee.Emcee` - ptemcee :code:`bilby.core.sampler.ptemcee.Ptemcee` - pymc3 :code:`bilby.core.sampler.pymc3.Pymc3` +- zeus :code:`bilby.core.sampler.zeus.Zeus` ------------------- diff --git a/sampler_requirements.txt b/sampler_requirements.txt index 443a96ab2a1ea2f4f7474436437a43d2b4afa0e5..d60093c8b1b3e2abd5af5397f02f1986655f5e04 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -10,3 +10,4 @@ ultranest>=3.0.0 dnest4 nessai>=0.2.3 schwimmbad +zeus-mcmc>=2.3.0 diff --git a/test/core/sampler/zeus_test.py b/test/core/sampler/zeus_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e7bd9289cf5e3115a017c8eeae50e21b99c12ed8 --- /dev/null +++ b/test/core/sampler/zeus_test.py @@ -0,0 +1,64 @@ +import unittest + +from mock import MagicMock + +import bilby + + +class TestZeus(unittest.TestCase): + def setUp(self): + self.likelihood = MagicMock() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) + ) + self.sampler = bilby.core.sampler.Zeus( + self.likelihood, + self.priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=True, + ) + + def tearDown(self): + del self.likelihood + del self.priors + del self.sampler + + def test_default_kwargs(self): + expected = dict( + nwalkers=500, + args=[], + kwargs={}, + pool=None, + log_prob0=None, + start=None, + blobs0=None, + iterations=100, + thin=1, + ) + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_translate_kwargs(self): + expected = dict( + nwalkers=100, + args=[], + kwargs={}, + pool=None, + log_prob0=None, + start=None, + blobs0=None, + iterations=100, + thin=1, + ) + for equiv in bilby.core.sampler.base_sampler.MCMCSampler.nwalkers_equiv_kwargs: + new_kwargs = self.sampler.kwargs.copy() + del new_kwargs["nwalkers"] + new_kwargs[equiv] = 100 + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + +if __name__ == "__main__": + unittest.main()