diff --git a/README.rst b/README.rst index 37b98613e0fb68572f378cfb9340b53856b3ac93..70bc2efa539a21c4614e3245a4ad7540f4d4bd6d 100644 --- a/README.rst +++ b/README.rst @@ -53,6 +53,7 @@ as requested in their associated documentation. * `pymultinest <https://github.com/JohannesBuchner/PyMultiNest>`__ * `cpnest <https://github.com/johnveitch/cpnest>`__ * `emcee <https://github.com/dfm/emcee>`__ +* `nessai <https://github.com/mj-will/nessai>`__ * `ptemcee <https://github.com/willvousden/ptemcee>`__ * `ptmcmcsampler <https://github.com/jellis18/PTMCMCSampler>`__ * `pypolychord <https://github.com/PolyChord/PolyChordLite>`__ diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index d0d05037031383ff9a22a08898856e06a6ddbf8d..a0d3e72ff5dbcc7224e6006fcda38ad20cf2f370 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -1,10 +1,12 @@ import os +import sys import numpy as np from pandas import DataFrame +from scipy.special import logsumexp from ..utils import check_directory_exists_and_if_not_mkdir, load_json, logger -from .base_sampler import NestedSampler +from .base_sampler import NestedSampler, signal_wrapper class Nessai(NestedSampler): @@ -19,8 +21,22 @@ class Nessai(NestedSampler): """ _default_kwargs = None + _run_kwargs_list = None sampling_seed_key = "seed" + @property + def run_kwargs_list(self): + """List of kwargs used in the run method of :code:`FlowSampler`""" + if not self._run_kwargs_list: + from nessai.utils.bilbyutils import get_run_kwargs_list + + self._run_kwargs_list = get_run_kwargs_list() + ignored_kwargs = ["save"] + for ik in ignored_kwargs: + if ik in self._run_kwargs_list: + self._run_kwargs_list.remove(ik) + return self._run_kwargs_list + @property def default_kwargs(self): """Default kwargs for nessai. @@ -28,32 +44,38 @@ class Nessai(NestedSampler): Retrieves default values from nessai directly and then includes any bilby specific defaults.
This avoids the need to update bilby when the defaults change or new kwargs are added to nessai. + + Includes the following kwargs that are specific to bilby: + + - :code:`nessai_log_level`: allows setting the logging level in nessai + - :code:`nessai_logging_stream`: allows setting the logging stream + - :code:`nessai_plot`: allows toggling the plotting in FlowSampler.run """ if not self._default_kwargs: - from inspect import signature - - from nessai.flowsampler import FlowSampler - from nessai.nestedsampler import NestedSampler - from nessai.proposal import AugmentedFlowProposal, FlowProposal - - kwargs = {} - classes = [ - AugmentedFlowProposal, - FlowProposal, - NestedSampler, - FlowSampler, - ] - for c in classes: - kwargs.update( - { - k: v.default - for k, v in signature(c).parameters.items() - if v.default is not v.empty - } - ) + from nessai.utils.bilbyutils import get_all_kwargs + + kwargs = get_all_kwargs() + # Defaults for bilby that will override nessai defaults - bilby_defaults = dict(output=None, exit_code=self.exit_code) + bilby_defaults = dict( + output=None, + exit_code=self.exit_code, + nessai_log_level=None, + nessai_logging_stream="stdout", + nessai_plot=True, + plot_posterior=False, # bilby already produces a posterior plot + log_on_iteration=False, # Use periodic logging by default + logging_interval=60, # Log every 60 seconds + ) kwargs.update(bilby_defaults) + # Kwargs that cannot be set in bilby + remove = [ + "save", + "signal_handling", + ] + for k in remove: + if k in kwargs: + kwargs.pop(k) self._default_kwargs = kwargs return self._default_kwargs @@ -72,12 +94,10 @@ class Nessai(NestedSampler): """ return self.priors.ln_prob(theta, axis=0) - def run_sampler(self): - from nessai.flowsampler import FlowSampler - from nessai.livepoint import dict_to_live_points, live_points_to_array + def get_nessai_model(self): + """Get the model for nessai.""" + from nessai.livepoint import dict_to_live_points from nessai.model import Model as 
BaseModel - from nessai.posterior import compute_weights - from nessai.utils import setup_logger class Model(BaseModel): """A wrapper class to pass our log_likelihood and priors into nessai @@ -124,47 +144,115 @@ class Nessai(NestedSampler): """Proposal probability for new the point""" return self.log_prior(x) - # Setup the logger for nessai using the same settings as the bilby logger - setup_logger( - self.outdir, label=self.label, log_level=logger.getEffectiveLevel() - ) + @staticmethod + def from_unit_hypercube(x): + """Map samples from the unit hypercube to the prior.""" + theta = {} + for n in self._search_parameter_keys: + theta[n] = self.priors[n].rescale(x[n]) + return dict_to_live_points(theta) + + @staticmethod + def to_unit_hypercube(x): + """Map samples from the prior to the unit hypercube.""" + theta = {n: x[n] for n in self._search_parameter_keys} + return dict_to_live_points(self.priors.cdf(theta)) + model = Model(self.search_parameter_keys, self.priors) - try: - out = FlowSampler(model, **self.kwargs) - out.run(save=True, plot=self.plot) - except TypeError as e: - raise TypeError(f"Unable to initialise nessai sampler with error: {e}") - except (SystemExit, KeyboardInterrupt) as e: - import sys - - logger.info( - f"Caught {type(e).__name__} with args {e.args}, " - f"exiting with signal {self.exit_code}" - ) - sys.exit(self.exit_code) + return model + + def split_kwargs(self): + """Split kwargs into configuration and run time kwargs""" + kwargs = self.kwargs.copy() + run_kwargs = {} + for k in self.run_kwargs_list: + run_kwargs[k] = kwargs.pop(k) + run_kwargs["plot"] = kwargs.pop("nessai_plot") + return kwargs, run_kwargs + + def get_posterior_weights(self): + """Get the posterior weights for the nested samples""" + from nessai.posterior import compute_weights + + _, log_weights = compute_weights( + np.array(self.fs.nested_samples["logL"]), + np.array(self.fs.ns.state.nlive), + ) + w = np.exp(log_weights - logsumexp(log_weights)) + return w + + def 
get_nested_samples(self): + """Get the nested samples dataframe""" + ns = DataFrame(self.fs.nested_samples) + ns.rename( + columns=dict(logL="log_likelihood", logP="log_prior", it="iteration"), + inplace=True, + ) + return ns + + def update_result(self): + """Update the result object.""" + from nessai.livepoint import live_points_to_array # Manually set likelihood evaluations because parallelisation breaks the counter - self.result.num_likelihood_evaluations = out.ns.likelihood_evaluations[-1] + self.result.num_likelihood_evaluations = self.fs.ns.total_likelihood_evaluations self.result.samples = live_points_to_array( - out.posterior_samples, self.search_parameter_keys + self.fs.posterior_samples, self.search_parameter_keys ) - self.result.log_likelihood_evaluations = out.posterior_samples["logL"] - self.result.nested_samples = DataFrame(out.nested_samples) - self.result.nested_samples.rename( - columns=dict(logL="log_likelihood", logP="log_prior"), inplace=True + self.result.log_likelihood_evaluations = self.fs.posterior_samples["logL"] + self.result.nested_samples = self.get_nested_samples() + self.result.nested_samples["weights"] = self.get_posterior_weights() + self.result.log_evidence = self.fs.log_evidence + self.result.log_evidence_err = self.fs.log_evidence_error + + @signal_wrapper + def run_sampler(self): + """Run the sampler. + + Nessai is designed to be run in two stages, initialise the sampler + and then call the run method with additional configuration. This means + there are effectively two sets of keyword arguments: one for + initializing the sampler and the other for the run function. + """ + from nessai.flowsampler import FlowSampler + from nessai.utils import setup_logger + + kwargs, run_kwargs = self.split_kwargs() + + # Setup the logger for nessai, use nessai_log_level if specified, else use + # the level of the bilby logger.
+ nessai_log_level = kwargs.pop("nessai_log_level") + if nessai_log_level is None or nessai_log_level == "bilby": + nessai_log_level = logger.getEffectiveLevel() + nessai_logging_stream = kwargs.pop("nessai_logging_stream") + + setup_logger( + self.outdir, + label=self.label, + log_level=nessai_log_level, + stream=nessai_logging_stream, ) - _, log_weights = compute_weights( - np.array(self.result.nested_samples.log_likelihood), - np.array(out.ns.state.nlive), + + # Get the nessai model + model = self.get_nessai_model() + + # Configure the sampler + self.fs = FlowSampler( + model, + signal_handling=False, # Disable signal handling so it can be handled by bilby + **kwargs, ) - self.result.nested_samples["weights"] = np.exp(log_weights) - self.result.log_evidence = out.ns.log_evidence - self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive) + # Run the sampler + self.fs.run(**run_kwargs) + + # Update the result + self.update_result() return self.result def _translate_kwargs(self, kwargs): + """Translate the keyword arguments""" super()._translate_kwargs(kwargs) if "nlive" not in kwargs: for equiv in self.npoints_equiv_kwargs: @@ -178,10 +266,7 @@ class Nessai(NestedSampler): kwargs["n_pool"] = self._npool def _verify_kwargs_against_default_kwargs(self): - """ - Set the directory where the output will be written - and check resume and checkpoint status. 
- """ + """Verify the keyword arguments""" if "config_file" in self.kwargs: d = load_json(self.kwargs["config_file"], None) self.kwargs.update(d) @@ -190,10 +275,6 @@ class Nessai(NestedSampler): if not self.kwargs["plot"]: self.kwargs["plot"] = self.plot - if self.kwargs["n_pool"] == 1 and self.kwargs["max_threads"] == 1: - logger.warning("Setting pool to None (n_pool=1 & max_threads=1)") - self.kwargs["n_pool"] = None - if not self.kwargs["output"]: self.kwargs["output"] = os.path.join( self.outdir, f"{self.label}_nessai", "" @@ -202,5 +283,21 @@ class Nessai(NestedSampler): check_directory_exists_and_if_not_mkdir(self.kwargs["output"]) NestedSampler._verify_kwargs_against_default_kwargs(self) + def write_current_state(self): + """Write the current state of the sampler""" + self.fs.ns.checkpoint() + + def write_current_state_and_exit(self, signum=None, frame=None): + """ + Overwrites the base class to make sure that :code:`Nessai` terminates + properly. + """ + if hasattr(self, "fs"): + self.fs.terminate_run(code=signum) + else: + logger.warning("Sampler is not initialized") + self._log_interruption(signum=signum) + sys.exit(self.exit_code) + def _setup_pool(self): pass diff --git a/sampler_requirements.txt b/sampler_requirements.txt index d6ed8e98ebce0797864ea463672eaf04ac600fd0..29f38dc1973f870ce1556b08889fc424bcb753b9 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -9,6 +9,6 @@ pymultinest kombine ultranest>=3.0.0 dnest4 -nessai>=0.2.3 +nessai>=0.7.0 schwimmbad zeus-mcmc>=2.3.0 diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index cbb084735ec50274b45d2dd629772c92d0d3daed..0cac7a45b24e9174336ed454e11908fd0e0e6555 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -21,9 +21,9 @@ class TestNessai(unittest.TestCase): plot=False, skip_import_verification=True, sampling_seed=150914, - npool=None, # TODO: remove when support for nessai<0.7.0 is dropped ) self.expected = 
self.sampler.default_kwargs + self.expected["n_pool"] = 1 # Because npool=1 by default self.expected['output'] = 'outdir/label_nessai/' self.expected['seed'] = 150914 @@ -48,28 +48,31 @@ class TestNessai(unittest.TestCase): def test_translate_kwargs_npool(self): expected = self.expected.copy() - expected["n_pool"] = None + expected["n_pool"] = 2 for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() del new_kwargs["n_pool"] - new_kwargs[equiv] = None + new_kwargs[equiv] = 2 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) - def test_translate_kwargs_seed(self): - assert self.expected["seed"] == 150914 + def test_split_kwargs(self): + kwargs, run_kwargs = self.sampler.split_kwargs() + assert "save" not in run_kwargs + assert "plot" in run_kwargs - def test_npool_max_threads(self): - # TODO: remove when support for nessai<0.7.0 is dropped + def test_translate_kwargs_no_npool(self): expected = self.expected.copy() - expected["n_pool"] = None - expected["max_threads"] = 1 + expected["n_pool"] = 3 new_kwargs = self.sampler.kwargs.copy() - new_kwargs["n_pool"] = 1 - new_kwargs["max_threads"] = 1 + del new_kwargs["n_pool"] + self.sampler._npool = 3 self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) + def test_translate_kwargs_seed(self): + assert self.expected["seed"] == 150914 + @patch("builtins.open", mock_open(read_data='{"nlive": 4000}')) def test_update_from_config_file(self): expected = self.expected.copy() diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index cdf549ccd519737916e97422f00b5de92feaf0ea..f9304971da4b5d91bc3475a1b3280576f96ef525 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -41,9 +41,8 @@ _sampler_kwargs = dict( kombine=dict(iterations=200, nwalkers=10, autoburnin=False), nessai=dict( nlive=100, - poolsize=1000, - 
max_iteration=1000, - max_threads=3, + poolsize=100, + max_iteration=500, ), nestle=dict(nlive=100), ptemcee=dict( @@ -159,11 +158,6 @@ class TestRunningSamplers(unittest.TestCase): pytest.skip(f"{sampler} cannot be parallelized") if sys.version_info.minor == 8 and sampler.lower == "cpnest": pytest.skip("Pool interrupting broken for cpnest with py3.8") - if sampler.lower() == "nessai" and pool_size > 1: - pytest.skip( - "Interrupting with a pool is failing in pytest. " - "Likely due to interactions with the signal handling in nessai." - ) pid = os.getpid() print(sampler)