Skip to content
Snippets Groups Projects
Commit 3123b4ad authored by Moritz Huebner's avatar Moritz Huebner
Browse files

Resolve "check_draw doesn't catch nan log_likelihood values."

parent e9d72733
No related branches found
No related tags found
1 merge request!965Resolve "check_draw doesn't catch nan log_likelihood values."
......@@ -446,6 +446,8 @@ class Sampler(object):
==========
theta: array_like
Parameter values at which to evaluate likelihood
warning: bool
Whether or not to print a warning
Returns
=======
......@@ -453,14 +455,19 @@ class Sampler(object):
True if the likelihood and prior are finite, false otherwise
"""
bad_values = [np.inf, np.nan_to_num(np.inf), np.nan]
if abs(self.log_prior(theta)) in bad_values:
log_p = self.log_prior(theta)
log_l = self.log_likelihood(theta)
return \
self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
@staticmethod
def _check_bad_value(val, warning, theta, label):
val = np.abs(val)
bad_values = [np.inf, np.nan_to_num(np.inf)]
if val in bad_values or np.isnan(val):
if warning:
logger.warning('Prior draw {} has inf prior'.format(theta))
return False
if abs(self.log_likelihood(theta)) in bad_values:
if warning:
logger.warning('Prior draw {} has inf likelihood'.format(theta))
logger.warning(f'Prior draw {theta} has inf {label}')
return False
return True
......
......@@ -98,6 +98,27 @@ class TestSampler(unittest.TestCase):
self.sampler.run_sampler()
self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__)
def test_bad_value_nan(self):
self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None)
def test_bad_value_np_abs_nan(self):
self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None)
def test_bad_value_abs_nan(self):
self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None)
def test_bad_value_pos_inf(self):
self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None)
def test_bad_value_neg_inf(self):
self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None)
def test_bad_value_pos_inf_nan_to_num(self):
self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None)
def test_bad_value_neg_inf_nan_to_num(self):
self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment