diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 32f63c5f9fd05eebea0b254bc10b57dc63994489..686c2290bcc5896403b5fd811c6661d268af26ee 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -18,6 +18,7 @@ from .ptemcee import Ptemcee from .ptmcmc import PTMCMCSampler from .pymc3 import Pymc3 from .pymultinest import Pymultinest +from .ultranest import Ultranest from .fake_sampler import FakeSampler from . import proposal @@ -25,7 +26,8 @@ IMPLEMENTED_SAMPLERS = { 'cpnest': Cpnest, 'dynamic_dynesty': DynamicDynesty, 'dynesty': Dynesty, 'emcee': Emcee, 'kombine': Kombine, 'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler, 'pymc3': Pymc3, 'pymultinest': Pymultinest, - 'pypolychord': PyPolyChord, 'fake_sampler': FakeSampler} + 'pypolychord': PyPolyChord, 'ultranest': Ultranest, + 'fake_sampler': FakeSampler} if command_line_args.sampler_help: sampler = command_line_args.sampler_help diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 197a6de7653c15f35751295082a1e7ed7826c232..a5ea48ebcbb58079aa227d794af8125a580e687e 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -535,7 +535,7 @@ class Sampler(object): class NestedSampler(Sampler): - npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints', 'npoint', 'Nlive'] + npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints', 'npoint', 'Nlive', 'num_live_points'] walks_equiv_kwargs = ['walks', 'steps', 'nmcmc'] def reorder_loglikelihoods(self, unsorted_loglikelihoods, unsorted_samples, diff --git a/bilby/core/sampler/ultranest.py b/bilby/core/sampler/ultranest.py new file mode 100644 index 0000000000000000000000000000000000000000..50dd4fe710cf02bb9429b4a798f79226dc26e34c --- /dev/null +++ b/bilby/core/sampler/ultranest.py @@ -0,0 +1,326 @@ +from __future__ import absolute_import + +import os +import shutil +import signal +import tempfile + +import numpy as np +from pandas import DataFrame + +from ..utils import check_directory_exists_and_if_not_mkdir, logger +from .base_sampler import NestedSampler + + +class Ultranest(NestedSampler): + """ + bilby wrapper of ultranest + (https://johannesbuchner.github.io/UltraNest/index.html) + + All positional and keyword arguments (i.e., the args and kwargs) passed to + `run_sampler` will be propagated to `ultranest.ReactiveNestedSampler.run` + or `ultranest.NestedSampler.run`, see documentation for those classes for + further help. Under Other Parameters, we list commonly used kwargs and the + bilby defaults. If the number of live points is specified the + `ultranest.NestedSampler` will be used, otherwise the + `ultranest.ReactiveNestedSampler` will be used. + + Other Parameters + ---------------- + num_live_points: int + The number of live points, note this can also equivalently be given as + one of [nlive, nlives, n_live_points, num_live_points]. If not given + then the `ultranest.ReactiveNestedSampler` will be used, which does not + require the number of live points to be specified. + show_status: Bool + If true, print information information about the convergence during + resume: bool + If true, resume run from checkpoint (if available) + step_sampler: + An UltraNest step sampler object. This defaults to None, so the default + stepping behaviour is used. + """ + + default_kwargs = dict( + resume=True, + show_status=True, + num_live_points=None, + wrapped_params=None, + log_dir=None, + derived_param_names=[], + run_num=None, + vectorized=False, + num_test_samples=2, + draw_multiple=True, + num_bootstraps=30, + update_interval_iter=None, + update_interval_ncall=None, + log_interval=None, + dlogz=None, + max_iters=None, + update_interval_iter_fraction=0.2, + viz_callback="auto", + dKL=0.5, + frac_remain=0.01, + Lepsilon=0.001, + min_ess=400, + max_ncalls=None, + max_num_improvement_loops=-1, + min_num_live_points=400, + cluster_num_live_points=40, + step_sampler=None, + ) + + def __init__( + self, + likelihood, + priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + exit_code=77, + skip_import_verification=False, + **kwargs, + ): + super(Ultranest, self).__init__( + likelihood=likelihood, + priors=priors, + outdir=outdir, + label=label, + use_ratio=use_ratio, + plot=plot, + skip_import_verification=skip_import_verification, + **kwargs, + ) + self._apply_ultranest_boundaries() + self.exit_code = exit_code + + signal.signal(signal.SIGTERM, self.write_current_state_and_exit) + signal.signal(signal.SIGINT, self.write_current_state_and_exit) + signal.signal(signal.SIGALRM, self.write_current_state_and_exit) + + def _translate_kwargs(self, kwargs): + if "num_live_points" not in kwargs: + for equiv in self.npoints_equiv_kwargs: + if equiv in kwargs: + kwargs["num_live_points"] = kwargs.pop(equiv) + + if "verbose" in kwargs and "show_status" not in kwargs: + kwargs["show_status"] = kwargs.pop("verbose") + + def _verify_kwargs_against_default_kwargs(self): + """ Check the kwargs """ + + self.outputfiles_basename = self.kwargs.pop("log_dir", None) + + NestedSampler._verify_kwargs_against_default_kwargs(self) + + def _apply_ultranest_boundaries(self): + if ( + self.kwargs["wrapped_params"] is None + or len(self.kwargs.get("wrapped_params", [])) == 0 + ): + self.kwargs["wrapped_params"] = [] + for param, value in self.priors.items(): + if value.boundary == "periodic": + self.kwargs["wrapped_params"].append(1) + else: + self.kwargs["wrapped_params"].append(0) + + @property + def outputfiles_basename(self): + return self._outputfiles_basename + + @outputfiles_basename.setter + def outputfiles_basename(self, outputfiles_basename): + if outputfiles_basename is None: + outputfiles_basename = "{}/ultra_{}".format(self.outdir, self.label) + if outputfiles_basename.endswith("/") is True: + outputfiles_basename = outputfiles_basename.rstrip("/") + check_directory_exists_and_if_not_mkdir(self.outdir) + self._outputfiles_basename = outputfiles_basename + + @property + def temporary_outputfiles_basename(self): + return self._temporary_outputfiles_basename + + @temporary_outputfiles_basename.setter + def temporary_outputfiles_basename(self, temporary_outputfiles_basename): + if temporary_outputfiles_basename.endswith("/") is False: + temporary_outputfiles_basename = "{}/".format( + temporary_outputfiles_basename + ) + self._temporary_outputfiles_basename = temporary_outputfiles_basename + if os.path.exists(self.outputfiles_basename): + shutil.copytree( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) + if os.path.islink(self.outputfiles_basename): + os.unlink(self.outputfiles_basename) + else: + shutil.rmtree(self.outputfiles_basename) + + def write_current_state_and_exit(self, signum=None, frame=None): + """ Write current state and exit on exit_code """ + logger.info( + "Run interrupted by signal {}: checkpoint and exit on {}".format( + signum, self.exit_code + ) + ) + # self.copy_temporary_directory_to_proper_path() + os._exit(self.exit_code) + + def copy_temporary_directory_to_proper_path(self): + logger.info( + "Overwriting {} with {}".format( + self.outputfiles_basename, self.temporary_outputfiles_basename + ) + ) + + # First remove anything in the outputfiles_basename for overwriting + if os.path.exists(self.outputfiles_basename): + if os.path.islink(self.outputfiles_basename): + os.unlink(self.outputfiles_basename) + else: + shutil.rmtree(self.outputfiles_basename, ignore_errors=True) + + shutil.copytree(self.temporary_outputfiles_basename, self.outputfiles_basename) + + @property + def sampler_function_kwargs(self): + if self.kwargs.get("num_live_points", None) is not None: + keys = [ + "update_interval_iter", + "update_interval_ncall", + "log_interval", + "dlogz", + "max_iters", + ] + else: + keys = [ + "update_interval_iter_fraction", + "update_interval_ncall", + "log_interval", + "show_status", + "viz_callback", + "dlogz", + "dKL", + "frac_remain", + "Lepsilon", + "min_ess", + "max_iters", + "max_ncalls", + "max_num_improvement_loops", + "min_num_live_points", + "cluster_num_live_points", + ] + + function_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} + + return function_kwargs + + @property + def sampler_init_kwargs(self): + keys = [ + "derived_param_names", + "resume", + "run_num", + "vectorized", + "log_dir", + "wrapped_params", + ] + if self.kwargs.get("num_live_points", None) is not None: + keys += ["num_live_points"] + else: + keys += ["num_test_samples", "draw_multiple", "num_bootstraps"] + + init_kwargs = {key: self.kwargs[key] for key in keys if key in self.kwargs} + + return init_kwargs + + def run_sampler(self): + import ultranest + import ultranest.stepsampler + + if self.kwargs["dlogz"] is None: + # remove dlogz, so ultranest defaults (which are different for + # NestedSampler and ReactiveNestedSampler) are used + self.kwargs.pop("dlogz") + + self._verify_kwargs_against_default_kwargs() + + stepsampler = self.kwargs.pop("step_sampler", None) + + temporary_outputfiles_basename = tempfile.TemporaryDirectory().name + self.temporary_outputfiles_basename = temporary_outputfiles_basename + logger.info("Using temporary file {}".format(temporary_outputfiles_basename)) + + check_directory_exists_and_if_not_mkdir(temporary_outputfiles_basename) + self.kwargs["log_dir"] = self.temporary_outputfiles_basename + + # Symlink the temporary directory with the target directory: ensures data is stored on exit + os.symlink( + os.path.abspath(self.temporary_outputfiles_basename), + os.path.abspath(self.outputfiles_basename), + target_is_directory=True, + ) + + # use reactive nested sampler when no live points are given + if self.kwargs.get("num_live_points", None) is not None: + integrator = ultranest.integrator.NestedSampler + else: + integrator = ultranest.integrator.ReactiveNestedSampler + + sampler = integrator( + self.search_parameter_keys, + self.log_likelihood, + transform=self.prior_transform, + **self.sampler_init_kwargs, + ) + + if stepsampler is not None: + if isinstance(stepsampler, ultranest.stepsampler.StepSampler): + sampler.stepsampler = stepsampler + else: + logger.warning( + "The supplied step sampler is not the correct type. " + "The default step sampling will be used instead." + ) + + results = sampler.run(**self.sampler_function_kwargs) + + self.copy_temporary_directory_to_proper_path() + + # Clean up + shutil.rmtree(temporary_outputfiles_basename) + + self._generate_result(results) + self.calc_likelihood_count() + + return self.result + + def _generate_result(self, out): + # extract results (samples stored in "v" will change to "points", + # weights stored in "w" will change to "weights") + datakey = "v" if "v" in out["weighted_samples"] else "points" + weightskey = "w" if "w" in out["weighted_samples"] else "weights" + data = np.array(out["weighted_samples"][datakey]) + weights = np.array(out["weighted_samples"][weightskey]) + + scaledweights = weights / weights.max() + mask = np.random.rand(len(scaledweights)) < scaledweights + + nested_samples = DataFrame(data, columns=self.search_parameter_keys) + nested_samples["weights"] = weights + nested_samples["log_likelihood"] = out["weighted_samples"]["L"] + self.result.log_likelihood_evaluations = np.array(out["weighted_samples"]["L"])[ + mask + ] + self.result.sampler_output = out + self.result.samples = data[mask, :] + self.result.nested_samples = nested_samples + self.result.log_evidence = out["logz"] + self.result.log_evidence_err = out["logzerr"] + + self.result.outputfiles_basename = self.outputfiles_basename diff --git a/docs/samplers.txt b/docs/samplers.txt index d6ea9ee15dd00ae4f3f9be821c44e549d8b010b7..fbbeff4c75d4e84c63b18521b3fd295886fad09c 100644 --- a/docs/samplers.txt +++ b/docs/samplers.txt @@ -84,6 +84,10 @@ PyPolyChord .. autoclass:: bilby.core.sampler.polychord.PyPolyChord +UltraNest +========= + +.. autoclass:: bilby.core.sampler.ultranest.Ultranest ------------- MCMC samplers diff --git a/sampler_requirements.txt b/sampler_requirements.txt index a6196d0e6aa04fecf495c6dcc290c03e27f83652..be75cbcf66d1fb9ed2580607f3a1bdc7543b2c2c 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -6,4 +6,5 @@ ptemcee pymc3==3.6; python_version <= '2.7' pymc3>=3.6; python_version > '3.4' pymultinest -kombine \ No newline at end of file +kombine +ultranest diff --git a/test/sampler_test.py b/test/sampler_test.py index f5399770033b2301bc5b6784a11e7225706134c4..0140b4f7b8a5c5d6d736adc1350e30ae25cf8f41 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -795,6 +795,100 @@ class TestPymultinest(unittest.TestCase): self.assertDictEqual(expected, self.sampler.kwargs) +class TestUltranest(unittest.TestCase): + + def setUp(self): + self.maxDiff = None + self.likelihood = MagicMock() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), + b=bilby.core.prior.Uniform(0, 1))) + self.priors["a"] = bilby.core.prior.Prior(boundary="periodic") + self.priors["b"] = bilby.core.prior.Prior(boundary="reflective") + self.sampler = bilby.core.sampler.Ultranest(self.likelihood, self.priors, + outdir="outdir", label="label", + use_ratio=False, plot=False, + skip_import_verification=True) + + def tearDown(self): + del self.likelihood + del self.priors + del self.sampler + shutil.rmtree("outdir") + + def test_default_kwargs(self): + expected = dict( + resume=True, + show_status=True, + num_live_points=None, + wrapped_params=None, + derived_param_names=None, + run_num=None, + vectorized=False, + num_test_samples=2, + draw_multiple=True, + num_bootstraps=30, + update_interval_iter=None, + update_interval_ncall=None, + log_interval=None, + dlogz=None, + max_iters=None, + update_interval_iter_fraction=0.2, + viz_callback="auto", + dKL=0.5, + frac_remain=0.01, + Lepsilon=0.001, + min_ess=400, + max_ncalls=None, + max_num_improvement_loops=-1, + min_num_live_points=400, + cluster_num_live_points=40, + step_sampler=None, + ) + self.assertListEqual([1, 0], self.sampler.kwargs["wrapped_params"]) # Check this separately + self.sampler.kwargs["wrapped_params"] = None # The dict comparison can't handle lists + self.sampler.kwargs["derived_param_names"] = None + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_translate_kwargs(self): + expected = dict( + resume=True, + show_status=True, + num_live_points=123, + wrapped_params=None, + derived_param_names=None, + run_num=None, + vectorized=False, + num_test_samples=2, + draw_multiple=True, + num_bootstraps=30, + update_interval_iter=None, + update_interval_ncall=None, + log_interval=None, + dlogz=None, + max_iters=None, + update_interval_iter_fraction=0.2, + viz_callback="auto", + dKL=0.5, + frac_remain=0.01, + Lepsilon=0.001, + min_ess=400, + max_ncalls=None, + max_num_improvement_loops=-1, + min_num_live_points=400, + cluster_num_live_points=40, + step_sampler=None, + ) + for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: + new_kwargs = self.sampler.kwargs.copy() + del new_kwargs['num_live_points'] + new_kwargs['wrapped_params'] = None # The dict comparison can't handle lists + new_kwargs["derived_param_names"] = None + new_kwargs[equiv] = 123 + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + class TestRunningSamplers(unittest.TestCase): def setUp(self): np.random.seed(42) @@ -937,6 +1031,19 @@ class TestRunningSamplers(unittest.TestCase): save=False, ) + def test_run_ultranest(self): + # run using NestedSampler (with nlive specified) + _ = bilby.run_sampler( + likelihood=self.likelihood, priors=self.priors, + sampler="ultranest", nlive=100, save=False, + ) + + # run using ReactiveNestedSampler (with no nlive given) + _ = bilby.run_sampler( + likelihood=self.likelihood, priors=self.priors, + sampler='ultranest', save=False, + ) + if __name__ == "__main__": unittest.main()