From bfb76c9a342754abeb70baac284122613eac0121 Mon Sep 17 00:00:00 2001
From: Moritz Huebner <moritz.huebner@ligo.org>
Date: Mon, 17 Jun 2019 22:40:28 -0500
Subject: [PATCH] Added a few tests and fixes a minor issue

---
 bilby/core/utils.py |  2 +-
 test/utils_test.py  | 16 ++++++++++++++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/bilby/core/utils.py b/bilby/core/utils.py
index a2ef15fea..ee7af17b4 100644
--- a/bilby/core/utils.py
+++ b/bilby/core/utils.py
@@ -201,7 +201,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration):
 
     """
     num = sampling_frequency * duration
-    if np.abs(num - np.round(num)) > _TOL:
+    if np.abs(num - np.round(num)) > 10**(-_TOL):
         raise IllegalDurationAndSamplingFrequencyException(
             '\nYour sampling frequency and duration must multiply to a number'
             'up to (tol = {}) decimals close to an integer number. '
diff --git a/test/utils_test.py b/test/utils_test.py
index 3b5645c8f..5ac7e1b96 100644
--- a/test/utils_test.py
+++ b/test/utils_test.py
@@ -114,6 +114,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
             self.time_array)
         self.assertEqual(self.sampling_frequency, new_sampling_freq)
 
+    def test_get_sampling_frequency_from_time_array_unequally_sampled(self):
+        self.time_array[-1] += 0.0001
+        with self.assertRaises(ValueError):
+            _, _ = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array)
+
     def test_get_duration_from_time_array(self):
         _, new_duration = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array)
         self.assertEqual(self.duration, new_duration)
@@ -127,6 +132,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
             self.frequency_array)
         self.assertEqual(self.sampling_frequency, new_sampling_freq)
 
+    def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self):
+        self.frequency_array[-1] += 0.0001
+        with self.assertRaises(ValueError):
+            _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array)
+
     def test_get_duration_from_frequency_array(self):
         _, new_duration = utils.get_sampling_frequency_and_duration_from_frequency_array(
             self.frequency_array)
@@ -148,6 +158,12 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
                                           duration=new_duration)
         self.assertTrue(np.allclose(self.frequency_array, new_frequency_array))
 
+    def test_illegal_sampling_frequency_and_duration(self):
+        with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException):
+            _ = utils.create_time_series(sampling_frequency=7.7,
+                                         duration=1.3,
+                                         starting_time=0)
+
 
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab