diff --git a/bilby/core/utils.py b/bilby/core/utils.py index a2ef15fea9eab89161ac0ff5f229253eb1488006..ee7af17b40a81eede8f7d2fe39307bbf363d2745 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -201,7 +201,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration): """ num = sampling_frequency * duration - if np.abs(num - np.round(num)) > _TOL: + if np.abs(num - np.round(num)) > 10**(-_TOL): raise IllegalDurationAndSamplingFrequencyException( '\nYour sampling frequency and duration must multiply to a number' 'up to (tol = {}) decimals close to an integer number. ' diff --git a/test/utils_test.py b/test/utils_test.py index 3b5645c8f0124a2790155bda3aa11569b3543f9e..5ac7e1b9638e60fbddbf392e36330f03e578e685 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -114,6 +114,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): self.time_array) self.assertEqual(self.sampling_frequency, new_sampling_freq) + def test_get_sampling_frequency_from_time_array_unequally_sampled(self): + self.time_array[-1] += 0.0001 + with self.assertRaises(ValueError): + _, _ = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) + def test_get_duration_from_time_array(self): _, new_duration = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array) self.assertEqual(self.duration, new_duration) @@ -127,6 +132,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): self.frequency_array) self.assertEqual(self.sampling_frequency, new_sampling_freq) + def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self): + self.frequency_array[-1] += 0.0001 + with self.assertRaises(ValueError): + _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array) + def test_get_duration_from_frequency_array(self): _, new_duration = utils.get_sampling_frequency_and_duration_from_frequency_array( self.frequency_array) @@ -148,6 +158,12 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): duration=new_duration) self.assertTrue(np.allclose(self.frequency_array, new_frequency_array)) + def test_illegal_sampling_frequency_and_duration(self): + with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException): + _ = utils.create_time_series(sampling_frequency=7.7, + duration=1.3, + starting_time=0) + if __name__ == '__main__': unittest.main()