Skip to content
Snippets Groups Projects
Commit b5b225a0 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Fix missing default kwarg n_effective in dynesty

parent ea619d83
No related branches found
No related tags found
No related merge requests found
......@@ -87,7 +87,7 @@ class Dynesty(NestedSampler):
update_interval=None, print_func=None,
dlogz=0.1, maxiter=None, maxcall=None,
logl_max=np.inf, add_live=True, print_progress=True,
save_bounds=False)
save_bounds=False, n_effective=None)
def __init__(self, likelihood, priors, outdir='outdir', label='label',
use_ratio=False, plot=False, skip_import_verification=False,
......@@ -131,7 +131,8 @@ class Dynesty(NestedSampler):
@property
def sampler_function_kwargs(self):
keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
'maxcall', 'logl_max', 'add_live', 'save_bounds']
'maxcall', 'logl_max', 'add_live', 'save_bounds',
'n_effective']
return {key: self.kwargs[key] for key in keys}
@property
......@@ -257,9 +258,29 @@ class Dynesty(NestedSampler):
return self.result
def _run_nested_wrapper(self, kwargs):
""" Wrapper function to run_nested
This wrapper catches exceptions related to different versions of
dynesty accepting different arguments.
Parameters
----------
kwargs: dict
The dictionary of kwargs to pass to run_nested
"""
logger.debug("Calling run_nested with sampler_function_kwargs {}"
.format(kwargs))
try:
self.sampler.run_nested(**kwargs)
except TypeError:
kwargs.pop("n_effective")
self.sampler.run_nested(**kwargs)
def _run_external_sampler_without_checkpointing(self):
logger.debug("Running sampler without checkpointing")
self.sampler.run_nested(**self.sampler_function_kwargs)
self._run_nested_wrapper(self.sampler_function_kwargs)
return self.sampler.results
def _run_external_sampler_with_checkpointing(self):
......@@ -276,7 +297,7 @@ class Dynesty(NestedSampler):
self.start_time = datetime.datetime.now()
while True:
sampler_kwargs['maxcall'] += self.n_check_point
self.sampler.run_nested(**sampler_kwargs)
self._run_nested_wrapper(sampler_kwargs)
if self.sampler.ncall == old_ncall:
break
old_ncall = self.sampler.ncall
......@@ -284,7 +305,7 @@ class Dynesty(NestedSampler):
self.write_current_state()
sampler_kwargs['add_live'] = True
self.sampler.run_nested(**sampler_kwargs)
self._run_nested_wrapper(sampler_kwargs)
return self.sampler.results
def _remove_checkpoint(self):
......
......@@ -155,7 +155,7 @@ class TestDynesty(unittest.TestCase):
enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None,
logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
walks=20, update_interval=600, print_func='func')
walks=20, update_interval=600, print_func='func', n_effective=None)
self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise
self.assertListEqual([0, 1], self.sampler.kwargs['periodic']) # Check this separately
self.sampler.kwargs['periodic'] = None # The dict comparison can't handle lists
......@@ -173,7 +173,7 @@ class TestDynesty(unittest.TestCase):
enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None,
logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
walks=20, update_interval=600, print_func='func')
walks=20, update_interval=600, print_func='func', n_effective=None)
for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
new_kwargs = self.sampler.kwargs.copy()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment