diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py index 7c8d7e8c34cdac1f11fea2bdd0e5770085b41f13..548e1a9b35c6b1e5a5e86264ad422d00844d28cf 100644 --- a/bilby/core/sampler/base_sampler.py +++ b/bilby/core/sampler/base_sampler.py @@ -446,6 +446,8 @@ class Sampler(object): ========== theta: array_like Parameter values at which to evaluate likelihood + warning: bool + Whether or not to print a warning Returns ======= @@ -453,14 +455,19 @@ class Sampler(object): True if the likelihood and prior are finite, false otherwise """ - bad_values = [np.inf, np.nan_to_num(np.inf), np.nan] - if abs(self.log_prior(theta)) in bad_values: + log_p = self.log_prior(theta) + log_l = self.log_likelihood(theta) + return \ + self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \ + self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood') + + @staticmethod + def _check_bad_value(val, warning, theta, label): + val = np.abs(val) + bad_values = [np.inf, np.nan_to_num(np.inf)] + if val in bad_values or np.isnan(val): if warning: - logger.warning('Prior draw {} has inf prior'.format(theta)) - return False - if abs(self.log_likelihood(theta)) in bad_values: - if warning: - logger.warning('Prior draw {} has inf likelihood'.format(theta)) + logger.warning(f'Prior draw {theta} has inf {label}') return False return True diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py index af3f6b749fc37f8bf22e944727dc0adc2da80868..d14eeaa4998b4f36b318029edb9f0b3d5c49c597 100644 --- a/test/core/sampler/base_sampler_test.py +++ b/test/core/sampler/base_sampler_test.py @@ -98,6 +98,27 @@ class TestSampler(unittest.TestCase): self.sampler.run_sampler() self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__) + def test_bad_value_nan(self): + self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None) + + def test_bad_value_np_abs_nan(self): + self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None) + + def test_bad_value_abs_nan(self): + self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None) + + def test_bad_value_pos_inf(self): + self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None) + + def test_bad_value_neg_inf(self): + self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None) + + def test_bad_value_pos_inf_nan_to_num(self): + self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None) + + def test_bad_value_neg_inf_nan_to_num(self): + self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None) + if __name__ == "__main__": unittest.main()