From bfb76c9a342754abeb70baac284122613eac0121 Mon Sep 17 00:00:00 2001 From: Moritz Huebner <moritz.huebner@ligo.org> Date: Mon, 17 Jun 2019 22:40:28 -0500 Subject: [PATCH] Added a few tests and fixes a minor issue --- bilby/core/utils.py | 2 +- test/utils_test.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/bilby/core/utils.py b/bilby/core/utils.py index a2ef15fea..ee7af17b4 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -201,7 +201,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > _TOL: + if np.abs(num - np.round(num)) > 10**(-_TOL): raise IllegalDurationAndSamplingFrequencyException( '\nYour sampling frequency and duration must multiply to a number' 'up to (tol = {}) decimals close to an integer number. ' diff --git a/test/utils_test.py b/test/utils_test.py index 3b5645c8f..5ac7e1b96 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -114,6 +114,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): self.time_array) self.assertEqual(self.sampling_frequency, new_sampling_freq) + def test_get_sampling_frequency_from_time_array_unequally_sampled(self): + self.time_array[-1] += 0.0001 + with self.assertRaises(ValueError): + _, _ = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) + def test_get_duration_from_time_array(self): _, new_duration = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) self.assertEqual(self.duration, new_duration) @@ -127,6 +132,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): self.frequency_array) self.assertEqual(self.sampling_frequency, new_sampling_freq) + def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): + self.frequency_array[-1] += 0.0001 + with self.assertRaises(ValueError): + _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) + def test_get_duration_from_frequency_array(self): _, new_duration = utils.get_sampling_frequency_and_duration_from_frequency_array( self.frequency_array) @@ -148,6 +158,12 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): duration=new_duration) self.assertTrue(np.allclose(self.frequency_array, new_frequency_array)) + def test_illegal_sampling_frequency_and_duration(self): + with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): + _ = utils.create_time_series(sampling_frequency=7.7, + duration=1.3, + starting_time=0) + if __name__ == '__main__': unittest.main() -- GitLab