diff --git a/bilby/core/sampler/nessai.py b/bilby/core/sampler/nessai.py index 0ad68945f9f9b0e27913e32e4c735be3c040a073..5e3ef364bb7a49cf90496fdf0afae2480a785789 100644 --- a/bilby/core/sampler/nessai.py +++ b/bilby/core/sampler/nessai.py @@ -1,4 +1,5 @@ import numpy as np +import os from pandas import DataFrame from .base_sampler import NestedSampler @@ -15,55 +16,42 @@ class Nessai(NestedSampler): Documentation: https://nessai.readthedocs.io/ """ - default_kwargs = dict( - output=None, - nlive=1000, - stopping=0.1, - resume=True, - max_iteration=None, - checkpointing=True, - seed=1234, - acceptance_threshold=0.01, - analytic_priors=False, - maximum_uninformed=1000, - uninformed_proposal=None, - uninformed_proposal_kwargs=None, - flow_class=None, - flow_config=None, - training_frequency=None, - reset_weights=False, - reset_permutations=False, - reset_acceptance=False, - train_on_empty=True, - cooldown=100, - memory=False, - poolsize=None, - drawsize=None, - max_poolsize_scale=10, - update_poolsize=False, - latent_prior='truncated_gaussian', - draw_latent_kwargs=None, - compute_radius_with_all=False, - min_radius=False, - max_radius=50, - check_acceptance=False, - fuzz=1.0, - expansion_fraction=1.0, - rescale_parameters=True, - rescale_bounds=[-1, 1], - update_bounds=False, - boundary_inversion=False, - inversion_type='split', detect_edges=False, - detect_edges_kwargs=None, - reparameterisations=None, - n_pool=None, - max_threads=1, - pytorch_threads=None, - plot=None, - proposal_plots=False - ) + _default_kwargs = None seed_equiv_kwargs = ['sampling_seed'] + @property + def default_kwargs(self): + """Default kwargs for nessai. + + Retrieves default values from nessai directly and then includes any + bilby specific defaults. This avoids the need to update bilby when the + defaults change or new kwargs are added to nessai. + """ + if not self._default_kwargs: + from inspect import signature + from nessai.flowsampler import FlowSampler + from nessai.nestedsampler import NestedSampler + from nessai.proposal import AugmentedFlowProposal, FlowProposal + + kwargs = {} + classes = [ + AugmentedFlowProposal, + FlowProposal, + NestedSampler, + FlowSampler, + ] + for c in classes: + kwargs.update( + {k: v.default for k, v in signature(c).parameters.items() if v.default is not v.empty} + ) + # Defaults for bilby that will override nessai defaults + bilby_defaults = dict( + output=None, + ) + kwargs.update(bilby_defaults) + self._default_kwargs = kwargs + return self._default_kwargs + def log_prior(self, theta): """ @@ -194,9 +182,9 @@ class Nessai(NestedSampler): self.kwargs['n_pool'] = None if not self.kwargs['output']: - self.kwargs['output'] = self.outdir + '/{}_nessai/'.format(self.label) - if self.kwargs['output'].endswith('/') is False: - self.kwargs['output'] = '{}/'.format(self.kwargs['output']) + self.kwargs['output'] = os.path.join( + self.outdir, '{}_nessai'.format(self.label), '' + ) check_directory_exists_and_if_not_mkdir(self.kwargs['output']) NestedSampler._verify_kwargs_against_default_kwargs(self) diff --git a/test/core/sampler/nessai_test.py b/test/core/sampler/nessai_test.py index 15b0f26afe179c7ce0b70f6cd234761142f605de..6f902f71d8f8e43a53c824345356bdd389eaa922 100644 --- a/test/core/sampler/nessai_test.py +++ b/test/core/sampler/nessai_test.py @@ -22,53 +22,8 @@ class TestNessai(unittest.TestCase): plot=False, skip_import_verification=True, ) - self.expected = dict( - output="outdir/label_nessai/", - nlive=1000, - stopping=0.1, - resume=True, - max_iteration=None, - checkpointing=True, - seed=1234, - acceptance_threshold=0.01, - analytic_priors=False, - maximum_uninformed=1000, - uninformed_proposal=None, - uninformed_proposal_kwargs=None, - flow_class=None, - flow_config=None, - training_frequency=None, - reset_weights=False, - reset_permutations=False, - reset_acceptance=False, - train_on_empty=True, - cooldown=100, - memory=False, - poolsize=None, - drawsize=None, - max_poolsize_scale=10, - update_poolsize=False, - latent_prior='truncated_gaussian', - draw_latent_kwargs=None, - compute_radius_with_all=False, - min_radius=False, - max_radius=50, - check_acceptance=False, - fuzz=1.0, - expansion_fraction=1.0, - rescale_parameters=True, - rescale_bounds=[-1, 1], - update_bounds=False, - boundary_inversion=False, - inversion_type='split', detect_edges=False, - detect_edges_kwargs=None, - reparameterisations=None, - n_pool=None, - max_threads=1, - pytorch_threads=None, - plot=False, - proposal_plots=False - ) + self.expected = self.sampler.default_kwargs + self.expected['output'] = 'outdir/label_nessai/' def tearDown(self): del self.likelihood @@ -76,16 +31,16 @@ class TestNessai(unittest.TestCase): del self.sampler del self.expected - def test_default_kwargs(self): - expected = self.expected.copy() - self.assertDictEqual(expected, self.sampler.kwargs) - def test_translate_kwargs_nlive(self): expected = self.expected.copy() + # nlive in the default kwargs is not a fixed but depends on the + # version of nessai, so get the value here and use it when setting + # the equivalent kwargs. + nlive = expected["nlive"] for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() del new_kwargs["nlive"] - new_kwargs[equiv] = 1000 + new_kwargs[equiv] = nlive self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) @@ -117,10 +72,10 @@ class TestNessai(unittest.TestCase): self.sampler.kwargs = new_kwargs self.assertDictEqual(expected, self.sampler.kwargs) - @patch("builtins.open", mock_open(read_data='{"nlive": 2000}')) + @patch("builtins.open", mock_open(read_data='{"nlive": 4000}')) def test_update_from_config_file(self): expected = self.expected.copy() - expected["nlive"] = 2000 + expected["nlive"] = 4000 new_kwargs = self.expected.copy() new_kwargs["config_file"] = "config_file.json" self.sampler.kwargs = new_kwargs