Skip to content
Snippets Groups Projects
Commit 3c896cda authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'add_utils_tests' into 'master'

Added a few tests and fixes a minor issue

See merge request lscsoft/bilby!528
parents f9c24258 bfb76c9a
No related branches found
No related tags found
No related merge requests found
......@@ -201,7 +201,7 @@ def _check_legal_sampling_frequency_and_duration(sampling_frequency, duration):
"""
num = sampling_frequency * duration
if np.abs(num - np.round(num)) > _TOL:
if np.abs(num - np.round(num)) > 10**(-_TOL):
raise IllegalDurationAndSamplingFrequencyException(
'\nYour sampling frequency and duration must multiply to a number'
'up to (tol = {}) decimals close to an integer number. '
......
......@@ -114,6 +114,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
self.time_array)
self.assertEqual(self.sampling_frequency, new_sampling_freq)
def test_get_sampling_frequency_from_time_array_unequally_sampled(self):
self.time_array[-1] += 0.0001
with self.assertRaises(ValueError):
_, _ = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array)
def test_get_duration_from_time_array(self):
_, new_duration = utils.get_sampling_frequency_and_duration_from_time_array(self.time_array)
self.assertEqual(self.duration, new_duration)
......@@ -127,6 +132,11 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
self.frequency_array)
self.assertEqual(self.sampling_frequency, new_sampling_freq)
def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self):
self.frequency_array[-1] += 0.0001
with self.assertRaises(ValueError):
_, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(self.frequency_array)
def test_get_duration_from_frequency_array(self):
_, new_duration = utils.get_sampling_frequency_and_duration_from_frequency_array(
self.frequency_array)
......@@ -148,6 +158,12 @@ class TestTimeAndFrequencyArrays(unittest.TestCase):
duration=new_duration)
self.assertTrue(np.allclose(self.frequency_array, new_frequency_array))
def test_illegal_sampling_frequency_and_duration(self):
with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException):
_ = utils.create_time_series(sampling_frequency=7.7,
duration=1.3,
starting_time=0)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment