From b5b225a0ab7e3f4aa53f47f412fa7dce54f3f517 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Mon, 16 Sep 2019 19:19:54 -0500 Subject: [PATCH] Fix missing default kwarg n_effective in dynesty --- bilby/core/sampler/dynesty.py | 31 ++++++++++++++++++++++++++----- test/sampler_test.py | 4 ++-- 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 4bcafd423..3bd8b55e4 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -87,7 +87,7 @@ class Dynesty(NestedSampler): update_interval=None, print_func=None, dlogz=0.1, maxiter=None, maxcall=None, logl_max=np.inf, add_live=True, print_progress=True, - save_bounds=False) + save_bounds=False, n_effective=None) def __init__(self, likelihood, priors, outdir='outdir', label='label', use_ratio=False, plot=False, skip_import_verification=False, @@ -131,7 +131,8 @@ class Dynesty(NestedSampler): @property def sampler_function_kwargs(self): keys = ['dlogz', 'print_progress', 'print_func', 'maxiter', - 'maxcall', 'logl_max', 'add_live', 'save_bounds'] + 'maxcall', 'logl_max', 'add_live', 'save_bounds', + 'n_effective'] return {key: self.kwargs[key] for key in keys} @property @@ -257,9 +258,29 @@ class Dynesty(NestedSampler): return self.result + def _run_nested_wrapper(self, kwargs): + """ Wrapper function to run_nested + + This wrapper catches exceptions related to different versions of + dynesty accepting different arguments. + + Parameters + ---------- + kwargs: dict + The dictionary of kwargs to pass to run_nested + + """ + logger.debug("Calling run_nested with sampler_function_kwargs {}" + .format(kwargs)) + try: + self.sampler.run_nested(**kwargs) + except TypeError: + kwargs.pop("n_effective") + self.sampler.run_nested(**kwargs) + def _run_external_sampler_without_checkpointing(self): logger.debug("Running sampler without checkpointing") - self.sampler.run_nested(**self.sampler_function_kwargs) + self._run_nested_wrapper(self.sampler_function_kwargs) return self.sampler.results def _run_external_sampler_with_checkpointing(self): @@ -276,7 +297,7 @@ class Dynesty(NestedSampler): self.start_time = datetime.datetime.now() while True: sampler_kwargs['maxcall'] += self.n_check_point - self.sampler.run_nested(**sampler_kwargs) + self._run_nested_wrapper(sampler_kwargs) if self.sampler.ncall == old_ncall: break old_ncall = self.sampler.ncall @@ -284,7 +305,7 @@ class Dynesty(NestedSampler): self.write_current_state() sampler_kwargs['add_live'] = True - self.sampler.run_nested(**sampler_kwargs) + self._run_nested_wrapper(sampler_kwargs) return self.sampler.results def _remove_checkpoint(self): diff --git a/test/sampler_test.py b/test/sampler_test.py index 9657cb286..a7cfe5978 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -155,7 +155,7 @@ class TestDynesty(unittest.TestCase): enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0, facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None, logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False, - walks=20, update_interval=600, print_func='func') + walks=20, update_interval=600, print_func='func', n_effective=None) self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise self.assertListEqual([0, 1], self.sampler.kwargs['periodic']) # Check this separately self.sampler.kwargs['periodic'] = None # The dict comparison can't handle lists @@ -173,7 +173,7 @@ class TestDynesty(unittest.TestCase): enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0, facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None, logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False, - walks=20, update_interval=600, print_func='func') + walks=20, update_interval=600, print_func='func', n_effective=None) for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs: new_kwargs = self.sampler.kwargs.copy() -- GitLab