diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 0e5bddfad198906701191d53fdd6061ec4653e4c..ccad84046b68957746175eaf2d7fa52895d9f77b 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -525,6 +525,7 @@ class Sampler(object): class NestedSampler(Sampler): npoints_equiv_kwargs = ['nlive', 'nlives', 'n_live_points', 'npoints', 'npoint', 'Nlive'] + walks_equiv_kwargs = ['walks', 'steps', 'nmcmc'] def reorder_loglikelihoods(self, unsorted_loglikelihoods, unsorted_samples, sorted_samples): diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 62526471903836bdcbcd3f7c958405889d1f9403..5b2222e1826cba1473aeb671ac9d6d6a7fa64df2 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -149,6 +149,10 @@ class Dynesty(NestedSampler): if 'print_progress' not in kwargs: if 'verbose' in kwargs: kwargs['print_progress'] = kwargs.pop('verbose') + if 'walks' not in kwargs: + for equiv in self.walks_equiv_kwargs: + if equiv in kwargs: + kwargs['walks'] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): if not self.kwargs['walks']: diff --git a/bilby/core/sampler/nestle.py b/bilby/core/sampler/nestle.py index e7cf8b20855017eab263811b240e7bc62900ba9b..0b97daf72b342a124183758fe212ea41bd1e32db 100644 --- a/bilby/core/sampler/nestle.py +++ b/bilby/core/sampler/nestle.py @@ -29,13 +29,17 @@ class Nestle(NestedSampler): default_kwargs = dict(verbose=True, method='multi', npoints=500, update_interval=None, npdim=None, maxiter=None, maxcall=None, dlogz=None, decline_factor=None, - rstate=None, callback=None) + rstate=None, callback=None, steps=20, enlarge=1.2) def _translate_kwargs(self, kwargs): if 'npoints' not in kwargs: for equiv in self.npoints_equiv_kwargs: if equiv in kwargs: kwargs['npoints'] = kwargs.pop(equiv) + if 'steps' not in kwargs: + for equiv in self.walks_equiv_kwargs: + if equiv in kwargs: + kwargs['steps'] = kwargs.pop(equiv) def _verify_kwargs_against_default_kwargs(self): if self.kwargs['verbose']: diff --git a/test/sampler_test.py b/test/sampler_test.py index accb9bd2cc820e927250ee700c301e605422f803..6bb4a4a24d51d91a1e1fd36e182c21a5a711deef 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -240,14 +240,14 @@ class TestNestle(unittest.TestCase): expected = dict(verbose=False, method='multi', npoints=500, update_interval=None, npdim=None, maxiter=None, maxcall=None, dlogz=None, decline_factor=None, - rstate=None, callback=None) + rstate=None, callback=None, steps=20, enlarge=1.2) self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): expected = dict(verbose=False, method='multi', npoints=345, update_interval=None, npdim=None, maxiter=None, maxcall=None, dlogz=None, decline_factor=None, - rstate=None, callback=None) + rstate=None, callback=None, steps=20, enlarge=1.2) self.sampler.kwargs['npoints'] = 123 for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy()