From b5b225a0ab7e3f4aa53f47f412fa7dce54f3f517 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Mon, 16 Sep 2019 19:19:54 -0500
Subject: [PATCH] Fix missing default kwarg n_effective in dynesty

---
 bilby/core/sampler/dynesty.py | 31 ++++++++++++++++++++++++++-----
 test/sampler_test.py          |  4 ++--
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 4bcafd423..3bd8b55e4 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -87,7 +87,7 @@ class Dynesty(NestedSampler):
                           update_interval=None, print_func=None,
                           dlogz=0.1, maxiter=None, maxcall=None,
                           logl_max=np.inf, add_live=True, print_progress=True,
-                          save_bounds=False)
+                          save_bounds=False, n_effective=None)
 
     def __init__(self, likelihood, priors, outdir='outdir', label='label',
                  use_ratio=False, plot=False, skip_import_verification=False,
@@ -131,7 +131,8 @@ class Dynesty(NestedSampler):
     @property
     def sampler_function_kwargs(self):
         keys = ['dlogz', 'print_progress', 'print_func', 'maxiter',
-                'maxcall', 'logl_max', 'add_live', 'save_bounds']
+                'maxcall', 'logl_max', 'add_live', 'save_bounds',
+                'n_effective']
         return {key: self.kwargs[key] for key in keys}
 
     @property
@@ -257,9 +258,29 @@ class Dynesty(NestedSampler):
 
         return self.result
 
+    def _run_nested_wrapper(self, kwargs):
+        """ Wrapper function to run_nested
+
+        This wrapper catches exceptions related to different versions of
+        dynesty accepting different arguments.
+
+        Parameters
+        ----------
+        kwargs: dict
+            The dictionary of kwargs to pass to run_nested
+
+        """
+        logger.debug("Calling run_nested with sampler_function_kwargs {}"
+                     .format(kwargs))
+        try:
+            self.sampler.run_nested(**kwargs)
+        except TypeError:
+            kwargs.pop("n_effective")
+            self.sampler.run_nested(**kwargs)
+
     def _run_external_sampler_without_checkpointing(self):
         logger.debug("Running sampler without checkpointing")
-        self.sampler.run_nested(**self.sampler_function_kwargs)
+        self._run_nested_wrapper(self.sampler_function_kwargs)
         return self.sampler.results
 
     def _run_external_sampler_with_checkpointing(self):
@@ -276,7 +297,7 @@ class Dynesty(NestedSampler):
         self.start_time = datetime.datetime.now()
         while True:
             sampler_kwargs['maxcall'] += self.n_check_point
-            self.sampler.run_nested(**sampler_kwargs)
+            self._run_nested_wrapper(sampler_kwargs)
             if self.sampler.ncall == old_ncall:
                 break
             old_ncall = self.sampler.ncall
@@ -284,7 +305,7 @@ class Dynesty(NestedSampler):
             self.write_current_state()
 
         sampler_kwargs['add_live'] = True
-        self.sampler.run_nested(**sampler_kwargs)
+        self._run_nested_wrapper(sampler_kwargs)
         return self.sampler.results
 
     def _remove_checkpoint(self):
diff --git a/test/sampler_test.py b/test/sampler_test.py
index 9657cb286..a7cfe5978 100644
--- a/test/sampler_test.py
+++ b/test/sampler_test.py
@@ -155,7 +155,7 @@ class TestDynesty(unittest.TestCase):
                         enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
                         facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None,
                         logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
-                        walks=20, update_interval=600, print_func='func')
+                        walks=20, update_interval=600, print_func='func', n_effective=None)
         self.sampler.kwargs['print_func'] = 'func'  # set this manually as this is not testable otherwise
         self.assertListEqual([0, 1], self.sampler.kwargs['periodic'])  # Check this separately
         self.sampler.kwargs['periodic'] = None  # The dict comparison can't handle lists
@@ -173,7 +173,7 @@ class TestDynesty(unittest.TestCase):
                         enlarge=None, bootstrap=None, vol_dec=0.5, vol_check=2.0,
                         facc=0.5, slices=5, dlogz=0.1, maxiter=None, maxcall=None,
                         logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
-                        walks=20, update_interval=600, print_func='func')
+                        walks=20, update_interval=600, print_func='func', n_effective=None)
 
         for equiv in bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs:
             new_kwargs = self.sampler.kwargs.copy()
-- 
GitLab