diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 6b4f3461932d2b19c4b4396ea21c9a365c9db5cc..3f73cb795840926e3140796656294d51b63f4cc3 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -248,9 +248,10 @@ class Sampler(object): self.exit_code = exit_code + self._log_likelihood_eval_time = np.nan if not soft_init: self._verify_parameters() - self._time_likelihood() + self._log_likelihood_eval_time = self._time_likelihood() self._verify_use_ratio() self.kwargs = kwargs @@ -433,6 +434,10 @@ class Sampler(object): n_evaluations: int The number of evaluations to estimate the evaluation time from + Returns + ======= + log_likelihood_eval_time: float + The time (in s) it took for one likelihood evaluation """ t1 = datetime.datetime.now() @@ -442,15 +447,16 @@ class Sampler(object): )[:, 0] self.log_likelihood(theta) total_time = (datetime.datetime.now() - t1).total_seconds() - self._log_likelihood_eval_time = total_time / n_evaluations + log_likelihood_eval_time = total_time / n_evaluations - if self._log_likelihood_eval_time == 0: - self._log_likelihood_eval_time = np.nan + if log_likelihood_eval_time == 0: + log_likelihood_eval_time = np.nan logger.info("Unable to measure single likelihood time") else: logger.info( f"Single likelihood evaluation took {self._log_likelihood_eval_time:.3e} s" ) + return log_likelihood_eval_time def _verify_use_ratio(self): """ diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index c2d88d5fcc2f38ee21936958e60fd3f430526215..faebfc6bfad1c11a1f116c0d6780c3cfce60aa2a 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -239,8 +239,12 @@ class Dynesty(NestedSampler): self.nestcheck = nestcheck if self.n_check_point is None: - self.n_check_point = max( - int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10 + self.n_check_point = ( + 10 + if np.isnan(self._log_likelihood_eval_time) + else max( + int(check_point_delta_t / self._log_likelihood_eval_time / 10), 10 + ) ) self.check_point_delta_t = check_point_delta_t logger.info(f"Checkpoint every check_point_delta_t = {check_point_delta_t}s") diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index 4856a9e7df4922bf94fe26746972bb1864774535..47cc2003e08da4bd972275de44a57682b8aa79cb 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -12,7 +12,7 @@ from bilby.core import prior class TestSampler(unittest.TestCase): - def setUp(self): + def setUp(self, soft_init=False): likelihood = bilby.core.likelihood.Likelihood() likelihood.parameters = dict(a=1, b=2, c=3) delta_prior = prior.DeltaFunction(peak=0) @@ -36,11 +36,16 @@ class TestSampler(unittest.TestCase): outdir=test_directory, use_ratio=False, skip_import_verification=True, + soft_init=soft_init ) def tearDown(self): del self.sampler + def test_softinit(self): + self.setUp(soft_init=True) + self.assertTrue(hasattr(self.sampler, "_log_likelihood_eval_time")) + def test_search_parameter_keys(self): expected_search_parameter_keys = ["c"] self.assertListEqual(