diff --git a/bilby/core/sampler/base_sampler.py b/bilby/core/sampler/base_sampler.py
index 7c8d7e8c34cdac1f11fea2bdd0e5770085b41f13..548e1a9b35c6b1e5a5e86264ad422d00844d28cf 100644
--- a/bilby/core/sampler/base_sampler.py
+++ b/bilby/core/sampler/base_sampler.py
@@ -446,6 +446,8 @@ class Sampler(object):
         ==========
         theta: array_like
             Parameter values at which to evaluate likelihood
+        warning: bool
+            Whether or not to print a warning
 
         Returns
         =======
@@ -453,14 +455,19 @@ class Sampler(object):
             True if the likelihood and prior are finite, false otherwise
 
         """
-        bad_values = [np.inf, np.nan_to_num(np.inf), np.nan]
-        if abs(self.log_prior(theta)) in bad_values:
+        log_p = self.log_prior(theta)
+        log_l = self.log_likelihood(theta)
+        return \
+            self._check_bad_value(val=log_p, warning=warning, theta=theta, label='prior') and \
+            self._check_bad_value(val=log_l, warning=warning, theta=theta, label='likelihood')
+
+    @staticmethod
+    def _check_bad_value(val, warning, theta, label):
+        val = np.abs(val)
+        bad_values = [np.inf, np.nan_to_num(np.inf)]
+        if val in bad_values or np.isnan(val):
             if warning:
-                logger.warning('Prior draw {} has inf prior'.format(theta))
-            return False
-        if abs(self.log_likelihood(theta)) in bad_values:
-            if warning:
-                logger.warning('Prior draw {} has inf likelihood'.format(theta))
+                logger.warning(f'Prior draw {theta} has inf {label}')
             return False
         return True
 
diff --git a/test/core/sampler/base_sampler_test.py b/test/core/sampler/base_sampler_test.py
index af3f6b749fc37f8bf22e944727dc0adc2da80868..d14eeaa4998b4f36b318029edb9f0b3d5c49c597 100644
--- a/test/core/sampler/base_sampler_test.py
+++ b/test/core/sampler/base_sampler_test.py
@@ -98,6 +98,27 @@ class TestSampler(unittest.TestCase):
         self.sampler.run_sampler()
         self.assertDictEqual(sampler_copy.__dict__, self.sampler.__dict__)
 
+    def test_bad_value_nan(self):
+        self.sampler._check_bad_value(val=np.nan, warning=False, theta=None, label=None)
+
+    def test_bad_value_np_abs_nan(self):
+        self.sampler._check_bad_value(val=np.abs(np.nan), warning=False, theta=None, label=None)
+
+    def test_bad_value_abs_nan(self):
+        self.sampler._check_bad_value(val=abs(np.nan), warning=False, theta=None, label=None)
+
+    def test_bad_value_pos_inf(self):
+        self.sampler._check_bad_value(val=np.inf, warning=False, theta=None, label=None)
+
+    def test_bad_value_neg_inf(self):
+        self.sampler._check_bad_value(val=-np.inf, warning=False, theta=None, label=None)
+
+    def test_bad_value_pos_inf_nan_to_num(self):
+        self.sampler._check_bad_value(val=np.nan_to_num(np.inf), warning=False, theta=None, label=None)
+
+    def test_bad_value_neg_inf_nan_to_num(self):
+        self.sampler._check_bad_value(val=np.nan_to_num(-np.inf), warning=False, theta=None, label=None)
+
 
 if __name__ == "__main__":
     unittest.main()