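# utils_test.py (8.03 KiB) -- authored by Gregory Ashton.
# NOTE: this header was recovered from a repository web view; the page's
# "Code owners" notice (assigning users and groups as approvers for specific
# file changes) is web-UI text and not part of the module.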
from __future__ import absolute_import, division
import unittest
import numpy as np
from astropy import constants
import bilby
from bilby.core import utils
class TestConstants(unittest.TestCase):
    """Compare bilby's physical constants with the values provided by astropy.

    The previous implementation called ``assertTrue(a, b)``.  ``assertTrue``
    takes ``(expr, msg)``, so the astropy value was only ever used as the
    failure message and the tests passed for any truthy bilby constant.
    The checks below perform an actual numerical comparison.
    """

    def _assert_relative_close(self, actual, expected, rtol):
        """Fail unless ``actual`` matches ``expected`` within ``rtol``.

        Parameters
        ----------
        actual: float
            Value under test (the bilby constant).
        expected: float
            Reference value (the astropy constant).
        rtol: float
            Maximum allowed value of ``|actual - expected| / |expected|``.

        Raises
        ------
        AssertionError
            If the relative difference exceeds ``rtol``.
        """
        relative_difference = abs(actual - expected) / abs(expected)
        self.assertLessEqual(
            relative_difference, rtol,
            msg="{} differs from reference {} by a relative {} (allowed {})".format(
                actual, expected, relative_difference, rtol))

    # NOTE(review): the two libraries are not guaranteed to carry bit-identical
    # values (they may follow different reference conventions or releases), so
    # agreement is checked to a relative tolerance instead of exact equality.
    # The tolerances are deliberately loose -- TODO tighten them once the
    # pinned bilby/astropy versions are confirmed.

    def test_speed_of_light(self):
        self._assert_relative_close(
            bilby.core.utils.speed_of_light, constants.c.value, rtol=1e-15)

    def test_parsec(self):
        self._assert_relative_close(
            bilby.core.utils.parsec, constants.pc.value, rtol=1e-9)

    def test_solar_mass(self):
        self._assert_relative_close(
            bilby.core.utils.solar_mass, constants.M_sun.value, rtol=1e-3)

    def test_radius_of_earth(self):
        self._assert_relative_close(
            bilby.core.utils.radius_of_earth, constants.R_earth.value, rtol=1e-2)
class TestFFT(unittest.TestCase):
    """Tests for the ``nfft`` / ``infft`` helpers in ``bilby.core.utils``."""

    def setUp(self):
        # Shared sampling frequency used by every test in this case.
        self.sampling_frequency = 10

    def tearDown(self):
        del self.sampling_frequency

    def test_nfft_sine_function(self):
        # A pure sine wave should produce a spectrum whose largest-magnitude
        # bin sits at (approximately) the injected frequency.
        duration = 100
        injected_frequency = 2.7324
        times = utils.create_time_series(self.sampling_frequency, duration)
        phase = 2 * np.pi * times * injected_frequency + 0.4
        time_domain_strain = np.sin(phase)

        spectrum, frequencies = bilby.core.utils.nfft(
            time_domain_strain, self.sampling_frequency)

        peak_index = np.argmax(np.abs(spectrum))
        frequency_at_peak = frequencies[peak_index]
        self.assertAlmostEqual(injected_frequency, frequency_at_peak, places=1)

    def test_nfft_infft(self):
        # Transforming random data forwards and then backwards should
        # reproduce the original samples.
        original = np.random.normal(0, 1, 10)
        spectrum, _ = bilby.core.utils.nfft(original, self.sampling_frequency)
        recovered = bilby.core.utils.infft(spectrum, self.sampling_frequency)
        round_trip_matches = np.allclose(original, recovered)
        self.assertTrue(round_trip_matches)
class TestInferParameters(unittest.TestCase):
    """Tests for the argument-introspection helpers in ``bilby.core.utils``."""

    def setUp(self):
        # Fixture 2: a bound method, so ``self`` is part of its signature.
        class TestClass:

            def test_method(self, a, b, *args, **kwargs):
                pass

        # Fixture 1: a plain function with a leading ``freqs`` argument
        # followed by the parameters of interest and catch-all arguments.
        def source_function(freqs, a, b, *args, **kwargs):
            return None

        self.source1 = source_function
        self.source2 = TestClass().test_method

    def tearDown(self):
        del self.source1
        del self.source2

    def test_args_kwargs_handling(self):
        # Only ``a`` and ``b`` are expected back for the plain function.
        inferred = utils.infer_parameters_from_function(self.source1)
        self.assertListEqual(['a', 'b'], inferred)

    def test_self_handling(self):
        # Only ``a`` and ``b`` are expected back for the bound method.
        inferred = utils.infer_args_from_method(self.source2)
        self.assertListEqual(['a', 'b'], inferred)
class TestTimeAndFrequencyArrays(unittest.TestCase):
    """Tests for creating time/frequency series and recovering their parameters."""

    def setUp(self):
        self.start_time = 1.3
        self.sampling_frequency = 5
        self.duration = 1.6
        self.frequency_array = utils.create_frequency_series(
            sampling_frequency=self.sampling_frequency,
            duration=self.duration)
        self.time_array = utils.create_time_series(
            sampling_frequency=self.sampling_frequency,
            duration=self.duration,
            starting_time=self.start_time)

    def tearDown(self):
        del self.start_time
        del self.sampling_frequency
        del self.duration
        del self.frequency_array
        del self.time_array

    # ------------------------------------------------------------------
    # Array creation
    # ------------------------------------------------------------------

    def test_create_time_array(self):
        expected = np.array([1.3, 1.5, 1.7, 1.9, 2.1, 2.3, 2.5, 2.7])
        created = utils.create_time_series(
            sampling_frequency=self.sampling_frequency,
            duration=self.duration,
            starting_time=self.start_time)
        self.assertTrue(np.allclose(expected, created))

    def test_create_frequency_array(self):
        expected = np.array([0.0, 0.625, 1.25, 1.875, 2.5])
        created = utils.create_frequency_series(
            sampling_frequency=self.sampling_frequency,
            duration=self.duration)
        self.assertTrue(np.allclose(expected, created))

    # ------------------------------------------------------------------
    # Recovering parameters from a time array
    # ------------------------------------------------------------------

    def test_get_sampling_frequency_from_time_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_time_array(
            self.time_array)
        new_sampling_freq, _ = recovered
        self.assertEqual(self.sampling_frequency, new_sampling_freq)

    def test_get_sampling_frequency_from_time_array_unequally_sampled(self):
        # Nudge the final sample so the spacing is no longer uniform.
        self.time_array[-1] += 0.0001
        with self.assertRaises(ValueError):
            _, _ = utils.get_sampling_frequency_and_duration_from_time_array(
                self.time_array)

    def test_get_duration_from_time_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_time_array(
            self.time_array)
        _, new_duration = recovered
        self.assertEqual(self.duration, new_duration)

    def test_get_start_time_from_time_array(self):
        first_sample = self.time_array[0]
        self.assertEqual(self.start_time, first_sample)

    # ------------------------------------------------------------------
    # Recovering parameters from a frequency array
    # ------------------------------------------------------------------

    def test_get_sampling_frequency_from_frequency_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_frequency_array(
            self.frequency_array)
        new_sampling_freq, _ = recovered
        self.assertEqual(self.sampling_frequency, new_sampling_freq)

    def test_get_sampling_frequency_from_frequency_array_unequally_sampled(self):
        # Nudge the final bin so the spacing is no longer uniform.
        self.frequency_array[-1] += 0.0001
        with self.assertRaises(ValueError):
            _, _ = utils.get_sampling_frequency_and_duration_from_frequency_array(
                self.frequency_array)

    def test_get_duration_from_frequency_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_frequency_array(
            self.frequency_array)
        _, new_duration = recovered
        self.assertEqual(self.duration, new_duration)

    # ------------------------------------------------------------------
    # Round trips: array -> parameters -> array
    # ------------------------------------------------------------------

    def test_consistency_time_array_to_time_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_time_array(
            self.time_array)
        new_sampling_frequency, new_duration = recovered
        new_start_time = self.time_array[0]
        rebuilt = utils.create_time_series(
            sampling_frequency=new_sampling_frequency,
            duration=new_duration,
            starting_time=new_start_time)
        self.assertTrue(np.allclose(self.time_array, rebuilt))

    def test_consistency_frequency_array_to_frequency_array(self):
        recovered = utils.get_sampling_frequency_and_duration_from_frequency_array(
            self.frequency_array)
        new_sampling_frequency, new_duration = recovered
        rebuilt = utils.create_frequency_series(
            sampling_frequency=new_sampling_frequency,
            duration=new_duration)
        self.assertTrue(np.allclose(self.frequency_array, rebuilt))

    # ------------------------------------------------------------------
    # Invalid input
    # ------------------------------------------------------------------

    def test_illegal_sampling_frequency_and_duration(self):
        # This sampling frequency / duration pair is expected to be rejected.
        with self.assertRaises(utils.IllegalDurationAndSamplingFrequencyException):
            utils.create_time_series(
                sampling_frequency=7.7,
                duration=1.3,
                starting_time=0)
class TestReflect(unittest.TestCase):
    """Tests for ``utils.reflect`` over several input intervals."""

    def _assert_reflects_to(self, xprime, x):
        # ``assert_allclose`` raises AssertionError on mismatch and returns
        # None otherwise, so calling it directly is a complete check.
        np.testing.assert_allclose(utils.reflect(xprime), x)

    def test_in_range(self):
        self._assert_reflects_to(
            xprime=np.array([0.1, 0.5, 0.9]),
            x=np.array([0.1, 0.5, 0.9]))

    def test_in_one_to_two(self):
        self._assert_reflects_to(
            xprime=np.array([1.1, 1.5, 1.9]),
            x=np.array([0.9, 0.5, 0.1]))

    def test_in_two_to_three(self):
        self._assert_reflects_to(
            xprime=np.array([2.1, 2.5, 2.9]),
            x=np.array([0.1, 0.5, 0.9]))

    def test_in_minus_one_to_zero(self):
        self._assert_reflects_to(
            xprime=np.array([-0.9, -0.5, -0.1]),
            x=np.array([0.9, 0.5, 0.1]))

    def test_in_minus_two_to_minus_one(self):
        self._assert_reflects_to(
            xprime=np.array([-1.9, -1.5, -1.1]),
            x=np.array([0.1, 0.5, 0.9]))
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()