From 3123b4ad89df3233c98e1a551d0d530701204a1d Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Thu, 3 Jun 2021 01:34:57 +0000 Subject: [PATCH] Resolve "check_draw doesn't catch nan log_likelihood values." --- bilby/core/sampler/base_sampler.py | 21 ++++++++++++++------- test/core/sampler/base_sampler_test.py | 21 +++++++++++++++++++++ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 7c8d7e8c3..548e1a9b3 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -446,6 +446,8 @@ class Sampler(object): ========== theta: array_like Parameter values at which to evaluate likelihood + warning: bool + Whether or not to print a warning Returns ======= @@ -453,14 +455,19 @@ class Sampler(object): True if the likelihood and prior are finite, false otherwise """ - bad_values = [np.inf, np.nan_to_num(np.inf), np.nan] - if abs(self.log_prior(theta)) in bad_values: + log_p = self.log_prior(theta) + log_l = self.log_likelihood(theta) + return \ + self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \ + self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood') + + @staticmethod + def _check_bad_value(val, warning, theta, label): + val = np.abs(val) + bad_values = [np.inf, np.nan_to_num(np.inf)] + if val in bad_values or np.isnan(val): if warning: - logger.warning('Prior draw {} has inf prior'.format(theta)) - return False - if abs(self.log_likelihood(theta)) in bad_values: - if warning: - logger.warning('Prior draw {} has inf likelihood'.format(theta)) + logger.warning(f'Prior draw {theta} has inf {label}') return False return True diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index af3f6b749..d14eeaa49 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -98,6 +98,27 @@ class TestSampler(unittest.TestCase): self.sampler.run_sampler() self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__) + def test_bad_value_nan(self): + self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None) + + def test_bad_value_np_abs_nan(self): + self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None) + + def test_bad_value_abs_nan(self): + self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None) + + def test_bad_value_pos_inf(self): + self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None) + + def test_bad_value_neg_inf(self): + self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None) + + def test_bad_value_pos_inf_nan_to_num(self): + self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None) + + def test_bad_value_neg_inf_nan_to_num(self): + self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None) + if __name__ == "__main__": unittest.main() -- GitLab