diff --git a/bilby/core/sampler/__init__.py b/bilby/core/sampler/__init__.py index 4972555cd3a36e381431e16ee766a1594da9063a..48858d6d4bebeb09726dc95ee95ca65424542e9b 100644 --- a/bilby/core/sampler/__init__.py +++ b/bilby/core/sampler/__init__.py @@ -13,6 +13,7 @@ from .dynamic_dynesty import DynamicDynesty from .dynesty import Dynesty from .emcee import Emcee from .kombine import Kombine +from .nessai import Nessai from .nestle import Nestle from .polychord import PyPolyChord from .ptemcee import Ptemcee @@ -26,10 +27,10 @@ from . import proposal IMPLEMENTED_SAMPLERS = { 'cpnest': Cpnest, 'dnest4': DNest4, 'dynamic_dynesty': DynamicDynesty, - 'dynesty': Dynesty, 'emcee': Emcee, 'kombine': Kombine, 'nestle': Nestle, - 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler, 'pymc3': Pymc3, - 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord, 'ultranest': Ultranest, - 'fake_sampler': FakeSampler} + 'dynesty': Dynesty, 'emcee': Emcee,'kombine': Kombine, 'nessai': Nessai, + 'nestle': Nestle, 'ptemcee': Ptemcee, 'ptmcmcsampler': PTMCMCSampler, + 'pymc3': Pymc3, 'pymultinest': Pymultinest, 'pypolychord': PyPolyChord, + 'ultranest': Ultranest, 'fake_sampler': FakeSampler} if command_line_args.sampler_help: sampler = command_line_args.sampler_help diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0bc630ad5a1d0f4b9063782829145c4bbcbe8a --- /dev/null +++ b/bilby/core/sampler/nessai.py @@ -0,0 +1,202 @@ +import numpy as np +from pandas import DataFrame + +from .base_sampler import NestedSampler +from ..utils import logger, check_directory_exists_and_if_not_mkdir, load_json + + +class Nessai(NestedSampler): + """bilby wrapper of nessai (https://github.com/mj-will/nessai) + + All positional and keyword arguments passed to `run_sampler` are propogated + to `nessai.flowsampler.FlowSampler` + + See the documentation for an explanation of the different kwargs. + + Documentation: https://nessai.readthedocs.io/ + """ + default_kwargs = dict( + output=None, + nlive=1000, + stopping=0.1, + resume=True, + max_iteration=None, + checkpointing=True, + seed=1234, + acceptance_threshold=0.01, + analytic_priors=False, + maximum_uninformed=1000, + uninformed_proposal=None, + uninformed_proposal_kwargs=None, + flow_class=None, + flow_config=None, + training_frequency=None, + reset_weights=False, + reset_permutations=False, + reset_acceptance=False, + train_on_empty=True, + cooldown=100, + memory=False, + poolsize=None, + drawsize=None, + max_poolsize_scale=10, + update_poolsize=False, + latent_prior='truncated_gaussian', + draw_latent_kwargs=None, + compute_radius_with_all=False, + min_radius=False, + max_radius=50, + check_acceptance=False, + fuzz=1.0, + expansion_fraction=1.0, + rescale_parameters=True, + rescale_bounds=[-1, 1], + update_bounds=False, + boundary_inversion=False, + inversion_type='split', detect_edges=False, + detect_edges_kwargs=None, + reparameterisations=None, + n_pool=None, + max_threads=1, + pytorch_threads=None, + plot=None, + proposal_plots=False + ) + seed_equiv_kwargs = ['sampling_seed'] + + def log_prior(self, theta): + """ + + Parameters + ---------- + theta: list + List of sampled values on a unit interval + + Returns + ------- + float: Joint ln prior probability of theta + + """ + return self.priors.ln_prob(theta, axis=0) + + def run_sampler(self): + from nessai.flowsampler import FlowSampler + from nessai.model import Model as BaseModel + from nessai.livepoint import dict_to_live_points + from nessai.posterior import compute_weights + from nessai.utils import setup_logger + + class Model(BaseModel): + """A wrapper class to pass our log_likelihood and priors into nessai + + Parameters + ---------- + names : list of str + List of parameters to sample + priors : :obj:`bilby.core.prior.PriorDict` + Priors to use for sampling. Needed for the bounds and the + `sample` method. + """ + def __init__(self, names, priors): + self.names = names + self.priors = priors + self._update_bounds() + + @staticmethod + def log_likelihood(x, **kwargs): + """Compute the log likelihood""" + theta = [x[n].item() for n in self.search_parameter_keys] + return self.log_likelihood(theta) + + @staticmethod + def log_prior(x, **kwargs): + """Compute the log prior""" + theta = {n: x[n] for n in self._search_parameter_keys} + return self.log_prior(theta) + + def _update_bounds(self): + self.bounds = {key: [self.priors[key].minimum, self.priors[key].maximum] + for key in self.names} + + def new_point(self, N=1): + """Draw a point from the prior""" + prior_samples = self.priors.sample(size=N) + samples = {n: prior_samples[n] for n in self.names} + return dict_to_live_points(samples) + + def new_point_log_prob(self, x): + """Proposal probability for new the point""" + return self.log_prior(x) + + # Setup the logger for nessai using the same settings as the bilby logger + setup_logger(self.outdir, label=self.label, + log_level=logger.getEffectiveLevel()) + model = Model(self.search_parameter_keys, self.priors) + out = None + while out is None: + try: + out = FlowSampler(model, **self.kwargs) + except TypeError as e: + raise TypeError("Unable to initialise nessai sampler with error: {}".format(e)) + try: + out.run(save=True, plot=self.plot) + except SystemExit as e: + import sys + logger.info("Caught exit code {}, exiting with signal {}".format(e.args[0], self.exit_code)) + sys.exit(self.exit_code) + + # Manually set likelihood evaluations because parallelisation breaks the counter + self.result.num_likelihood_evaluations = out.ns.likelihood_evaluations[-1] + + self.result.posterior = DataFrame(out.posterior_samples) + self.result.nested_samples = DataFrame(out.nested_samples) + self.result.nested_samples.rename( + columns=dict(logL='log_likelihood', logP='log_prior'), inplace=True) + self.result.posterior.rename( + columns=dict(logL='log_likelihood', logP='log_prior'), inplace=True) + _, log_weights = compute_weights(np.array(self.result.nested_samples.log_likelihood), + np.array(out.ns.state.nlive)) + self.result.nested_samples['weights'] = np.exp(log_weights) + self.result.log_evidence = out.ns.log_evidence + self.result.log_evidence_err = np.sqrt(out.ns.information / out.ns.nlive) + + return self.result + + def _translate_kwargs(self, kwargs): + if 'nlive' not in kwargs: + for equiv in self.npoints_equiv_kwargs: + if equiv in kwargs: + kwargs['nlive'] = kwargs.pop(equiv) + if 'n_pool' not in kwargs: + for equiv in self.npool_equiv_kwargs: + if equiv in kwargs: + kwargs['n_pool'] = kwargs.pop(equiv) + if 'seed' not in kwargs: + for equiv in self.seed_equiv_kwargs: + if equiv in kwargs: + kwargs['seed'] = kwargs.pop(equiv) + + def _verify_kwargs_against_default_kwargs(self): + """ + Set the directory where the output will be written + and check resume and checkpoint status. + """ + if 'config_file' in self.kwargs: + d = load_json(self.kwargs['config_file'], None) + self.kwargs.update(d) + self.kwargs.pop('config_file') + + if not self.kwargs['plot']: + self.kwargs['plot'] = self.plot + + if self.kwargs['n_pool'] == 1 and self.kwargs['max_threads'] == 1: + logger.warning('Setting pool to None (n_pool=1 & max_threads=1)') + self.kwargs['n_pool'] = None + + if not self.kwargs['output']: + self.kwargs['output'] = self.outdir + '/{}_nessai/'.format(self.label) + if self.kwargs['output'].endswith('/') is False: + self.kwargs['output'] = '{}/'.format(self.kwargs['output']) + + check_directory_exists_and_if_not_mkdir(self.kwargs['output']) + NestedSampler._verify_kwargs_against_default_kwargs(self) diff --git a/docs/samplers.txt b/docs/samplers.txt index a36471a5a1d1e9b39ab4e57e11454b3ba7057e9b..45adc38adac803604996b1619a1e5767a67c1621 100644 --- a/docs/samplers.txt +++ b/docs/samplers.txt @@ -60,6 +60,7 @@ Nested Samplers - PyPolyChord :code:`bilby.core.sampler.polychord.PyPolyChord` - UltraNest :code:`bilby.core.sampler.ultranest.Ultranest` - DNest4 :code:`bilby.core.sampler.dnest4.DNest4` +- Nessai: code:`bilby.core.sampler.nessai.Nessai` ------------- MCMC samplers diff --git a/sampler_requirements.txt b/sampler_requirements.txt index 08202cc3890a9fa56ea2b84a0f786f7568d515cb..2e7155585bedfbc9630e2a2406dc3c171d95f7b0 100644 --- a/sampler_requirements.txt +++ b/sampler_requirements.txt @@ -9,3 +9,4 @@ pymultinest kombine ultranest>=3.0.0 dnest4 +nessai>=0.2.3 diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py new file mode 100644 index 0000000000000000000000000000000000000000..15b0f26afe179c7ce0b70f6cd234761142f605de --- /dev/null +++ b/test/core/sampler/nessai_test.py @@ -0,0 +1,131 @@ +import unittest + +from mock import MagicMock, patch, mock_open + +import bilby + + +class TestNessai(unittest.TestCase): + maxDiff = None + + def setUp(self): + self.likelihood = MagicMock() + self.priors = bilby.core.prior.PriorDict( + dict(a=bilby.core.prior.Uniform(0, 1), b=bilby.core.prior.Uniform(0, 1)) + ) + self.sampler = bilby.core.sampler.Nessai( + self.likelihood, + self.priors, + outdir="outdir", + label="label", + use_ratio=False, + plot=False, + skip_import_verification=True, + ) + self.expected = dict( + output="outdir/label_nessai/", + nlive=1000, + stopping=0.1, + resume=True, + max_iteration=None, + checkpointing=True, + seed=1234, + acceptance_threshold=0.01, + analytic_priors=False, + maximum_uninformed=1000, + uninformed_proposal=None, + uninformed_proposal_kwargs=None, + flow_class=None, + flow_config=None, + training_frequency=None, + reset_weights=False, + reset_permutations=False, + reset_acceptance=False, + train_on_empty=True, + cooldown=100, + memory=False, + poolsize=None, + drawsize=None, + max_poolsize_scale=10, + update_poolsize=False, + latent_prior='truncated_gaussian', + draw_latent_kwargs=None, + compute_radius_with_all=False, + min_radius=False, + max_radius=50, + check_acceptance=False, + fuzz=1.0, + expansion_fraction=1.0, + rescale_parameters=True, + rescale_bounds=[-1, 1], + update_bounds=False, + boundary_inversion=False, + inversion_type='split', detect_edges=False, + detect_edges_kwargs=None, + reparameterisations=None, + n_pool=None, + max_threads=1, + pytorch_threads=None, + plot=False, + proposal_plots=False + ) + + def tearDown(self): + del self.likelihood + del self.priors + del self.sampler + del self.expected + + def test_default_kwargs(self): + expected = self.expected.copy() + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_translate_kwargs_nlive(self): + expected = self.expected.copy() + for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: + new_kwargs = self.sampler.kwargs.copy() + del new_kwargs["nlive"] + new_kwargs[equiv] = 1000 + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_translate_kwargs_npool(self): + expected = self.expected.copy() + expected["n_pool"] = None + for equiv in bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs: + new_kwargs = self.sampler.kwargs.copy() + del new_kwargs["n_pool"] + new_kwargs[equiv] = None + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_translate_kwargs_seed(self): + expected = self.expected.copy() + expected["seed"] = 150914 + for equiv in bilby.core.sampler.nessai.Nessai.seed_equiv_kwargs: + new_kwargs = self.sampler.kwargs.copy() + del new_kwargs["seed"] + new_kwargs[equiv] = 150914 + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + def test_npool_max_threads(self): + expected = self.expected.copy() + expected["n_pool"] = None + new_kwargs = self.sampler.kwargs.copy() + new_kwargs["n_pool"] = 1 + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + @patch("builtins.open", mock_open(read_data='{"nlive": 2000}')) + def test_update_from_config_file(self): + expected = self.expected.copy() + expected["nlive"] = 2000 + new_kwargs = self.expected.copy() + new_kwargs["config_file"] = "config_file.json" + self.sampler.kwargs = new_kwargs + self.assertDictEqual(expected, self.sampler.kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/integration/sampler_run_test.py b/test/integration/sampler_run_test.py index d5315cd00ba25db25f0904ecd0f89404fa1bed7a..1e263986b7e513773628f1ea733f28df70d1f51b 100644 --- a/test/integration/sampler_run_test.py +++ b/test/integration/sampler_run_test.py @@ -105,6 +105,17 @@ class TestRunningSamplers(unittest.TestCase): save=False, ) + def test_run_nessai(self): + _ = bilby.run_sampler( + likelihood=self.likelihood, + priors=self.priors, + sampler="nessai", + nlive=100, + poolsize=1000, + max_iteration=1000, + save=False, + ) + def test_run_pypolychord(self): pytest.importorskip("pypolychord") _ = bilby.run_sampler(