From 7605bff05b3b7f22fcee6677fd4942dd8fb3018c Mon Sep 17 00:00:00 2001
From: Alexandre Goettel <alexandresebastien.goettel@ligo.org>
Date: Fri, 15 Mar 2024 21:54:33 +0000
Subject: [PATCH] BUG: using the soft_init sampler kwargs with dynesty caused
 an AttributeError

---
 bilby/core/sampler/base_sampler.py     | 14 ++++++++++----
 bilby/core/sampler/dynesty.py          |  8 ++++++--
 test/core/sampler/base_sampler_test.py |  7 ++++++-
 3 files changed, 22 insertions(+), 7 deletions(-)

diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 6b4f34619..3f73cb795 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -248,9 +248,10 @@ class Sampler(object):
 
         self.exit_code = exit_code
 
+        self._log_likelihood_eval_time = np.nan
         if not soft_init:
             self._verify_parameters()
-            self._time_likelihood()
+            self._log_likelihood_eval_time = self._time_likelihood()
             self._verify_use_ratio()
 
         self.kwargs = kwargs
@@ -433,6 +434,10 @@ class Sampler(object):
         n_evaluations: int
             The number of evaluations to estimate the evaluation time from
 
+        Returns
+        =======
+        log_likelihood_eval_time: float
+            The time (in s) it took for one likelihood evaluation
         """
 
         t1 = datetime.datetime.now()
@@ -442,15 +447,16 @@ class Sampler(object):
             )[:, 0]
             self.log_likelihood(theta)
         total_time = (datetime.datetime.now() - t1).total_seconds()
-        self._log_likelihood_eval_time = total_time / n_evaluations
+        log_likelihood_eval_time = total_time / n_evaluations
 
-        if self._log_likelihood_eval_time == 0:
-            self._log_likelihood_eval_time = np.nan
+        if log_likelihood_eval_time == 0:
+            log_likelihood_eval_time = np.nan
             logger.info("Unable to measure single likelihood time")
         else:
             logger.info(
                 f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s"
             )
+        return log_likelihood_eval_time
 
     def _verify_use_ratio(self):
         """
diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index c2d88d5fc..faebfc6bf 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -239,8 +239,12 @@ class Dynesty(NestedSampler):
         self.nestcheck = nestcheck
 
         if self.n_check_point is None:
-            self.n_check_point = max(
-                int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10
+            self.n_check_point = (
+                10
+                if np.isnan(self._log_likelihood_eval_time)
+                else max(
+                    int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10
+                )
             )
         self.check_point_delta_t = check_point_delta_t
         logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s")
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index 4856a9e7d..47cc2003e 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -12,7 +12,7 @@ from bilby.core import prior
 
 
 class TestSampler(unittest.TestCase):
-    def setUp(self):
+    def setUp(self, soft_init=False):
         likelihood = bilby.core.likelihood.Likelihood()
         likelihood.parameters = dict(a=1, b=2, c=3)
         delta_prior = prior.DeltaFunction(peak=0)
@@ -36,11 +36,16 @@ class TestSampler(unittest.TestCase):
             outdir=test_directory,
             use_ratio=False,
             skip_import_verification=True,
+            soft_init=soft_init
         )
 
     def tearDown(self):
         del self.sampler
 
+    def test_softinit(self):
+        self.setUp(soft_init=True)
+        self.assertTrue(hasattr(self.sampler, "_log_likelihood_eval_time"))
+
     def test_search_parameter_keys(self):
         expected_search_parameter_keys = ["c"]
         self.assertListEqual(
-- 
GitLab