From 88ed4d97215667843b9f83c130d9ae24640a4027 Mon Sep 17 00:00:00 2001 From: Gregory Ashton <gregory.ashton@ligo.org> Date: Wed, 16 Dec 2020 04:36:56 -0600 Subject: [PATCH] Adds functionality to notch data --- bilby/gw/detector/strain_data.py | 138 ++++++++++++++++++++++++++- test/gw/detector/strain_data_test.py | 105 ++++++++++++++++++++ 2 files changed, 240 insertions(+), 3 deletions(-) diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index 6ba6db039..da0a39782 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -31,7 +31,7 @@ class InterferometerStrainData(object): time_array = PropertyAccessor('_times_and_frequencies', 'time_array') def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, - roll_off=0.2): + roll_off=0.2, notch_list=None): """ Initiate an InterferometerStrainData object The initialised object contains no data, this should be added using one @@ -46,11 +46,14 @@ class InterferometerStrainData(object): roll_off: float The roll-off (in seconds) used in the Tukey window, default=0.2s. This corresponds to alpha * duration / 2 for scipy tukey window. + notch_list: bilby.gw.detector.strain_data.NotchList + A list of notches """ self.minimum_frequency = minimum_frequency self.maximum_frequency = maximum_frequency + self.notch_list = notch_list self.roll_off = roll_off self.window_factor = 1 @@ -122,18 +125,46 @@ class InterferometerStrainData(object): self._maximum_frequency = maximum_frequency self._frequency_mask_updated = False + @property + def notch_list(self): + return self._notch_list + + @notch_list.setter + def notch_list(self, notch_list): + """ Set the notch_list + + Parameters + ---------- + notch_list: list, bilby.gw.detector.strain_data.NotchList + A list of length-2 tuples of the (max, min) frequency for the + notches or a pre-made bilby NotchList. + + """ + if notch_list is None: + self._notch_list = NotchList(None) + elif isinstance(notch_list, list): + self._notch_list = NotchList(notch_list) + elif isinstance(notch_list, NotchList): + self._notch_list = notch_list + else: + raise ValueError("notch_list {} not understood".format(notch_list)) + self._frequency_mask_updated = False + @property def frequency_mask(self): - """Masking array for limiting the frequency band. + """ Masking array for limiting the frequency band. Returns ------- - array_like: An array of boolean values + mask: np.ndarray + An array of boolean values """ if not self._frequency_mask_updated: frequency_array = self._times_and_frequencies.frequency_array mask = ((frequency_array >= self.minimum_frequency) & (frequency_array <= self.maximum_frequency)) + for notch in self.notch_list: + mask[notch.get_idxs(frequency_array)] = False self._frequency_mask = mask self._frequency_mask_updated = True return self._frequency_mask @@ -683,3 +714,104 @@ class InterferometerStrainData(object): strain = strain.resample(sampling_frequency) self.set_from_gwpy_timeseries(strain) + + +class Notch(object): + def __init__(self, minimum_frequency, maximum_frequency): + """ A notch object storing the maximum and minimum frequency of the notch + + Parameters + ---------- + minimum_frequency, maximum_frequency: float + The minimum and maximum frequency of the notch + + """ + + if 0 < minimum_frequency < maximum_frequency < np.inf: + self.minimum_frequency = minimum_frequency + self.maximum_frequency = maximum_frequency + else: + msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid" + .format(minimum_frequency, maximum_frequency)) + raise ValueError(msg) + + def get_idxs(self, frequency_array): + """ Get a boolean mask for the frequencies in frequency_array in the notch + + Parameters + ---------- + frequency_array: np.ndarray + An array of frequencies + + Returns + ------- + idxs: np.ndarray + An array of booleans which are True for frequencies in the notch + + """ + lower = (frequency_array > self.minimum_frequency) + upper = (frequency_array < self.maximum_frequency) + return lower & upper + + def check_frequency(self, freq): + """ Check if freq is inside the notch + + Parameters + ---------- + freq: float + The frequency to check + + Returns + ------- + True/False: + If freq inside the notch, return True, else False + """ + + if self.minimum_frequency < freq < self.maximum_frequency: + return True + else: + return False + + +class NotchList(list): + def __init__(self, notch_list): + """ A list of notches + + Parameters + ---------- + notch_list: list + A list of length-2 tuples of the (max, min) frequency for the + notches. + + Raises + ------ + ValueError + If the list is malformed. + """ + + if notch_list is not None: + for notch in notch_list: + if isinstance(notch, tuple) and len(notch) == 2: + self.append(Notch(*notch)) + else: + msg = "notch_list {} is malformed".format(notch_list) + raise ValueError(msg) + + def check_frequency(self, freq): + """ Check if freq is inside the notch list + + Parameters + ---------- + freq: float + The frequency to check + + Returns + ------- + True/False: + If freq inside any of the notches, return True, else False + """ + + for notch in self: + if notch.check_frequency(freq): + return True + return False diff --git a/test/gw/detector/strain_data_test.py b/test/gw/detector/strain_data_test.py index a11e347df..9cba74104 100644 --- a/test/gw/detector/strain_data_test.py +++ b/test/gw/detector/strain_data_test.py @@ -35,6 +35,49 @@ class TestInterferometerStrainData(unittest.TestCase): np.array_equal(self.ifosd.frequency_mask, [False, True, False]) ) + def test_frequency_mask_2(self): + strain_data = bilby.gw.detector.InterferometerStrainData( + minimum_frequency=20, maximum_frequency=512) + strain_data.set_from_time_domain_strain( + time_domain_strain=np.random.normal(0, 1, 4096), + time_array=np.arange(0, 4, 4 / 4096) + ) + + # Test from init + freqs = strain_data.frequency_array[strain_data.frequency_mask] + self.assertTrue(all(freqs >= 20)) + self.assertTrue(all(freqs <= 512)) + + # Test from update + strain_data.minimum_frequency = 30 + strain_data.maximum_frequency = 256 + freqs = strain_data.frequency_array[strain_data.frequency_mask] + self.assertTrue(all(freqs >= 30)) + self.assertTrue(all(freqs <= 256)) + + def test_notches_frequency_mask(self): + strain_data = bilby.gw.detector.InterferometerStrainData( + minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)]) + strain_data.set_from_time_domain_strain( + time_domain_strain=np.random.normal(0, 1, 4096), + time_array=np.arange(0, 4, 4 / 4096) + ) + + # Test from init + freqs = strain_data.frequency_array[strain_data.frequency_mask] + idxs = (freqs > 100) * (freqs < 101) + self.assertTrue(len(freqs[idxs]) == 0) + + # Test from setting + idxs = (freqs > 200) * (freqs < 201) + self.assertTrue(len(freqs[idxs]) > 0) + strain_data.notch_list = [(100, 101), (200, 201)] + freqs = strain_data.frequency_array[strain_data.frequency_mask] + idxs = (freqs > 200) * (freqs < 201) + self.assertTrue(len(freqs[idxs]) == 0) + idxs = (freqs > 100) * (freqs < 101) + self.assertTrue(len(freqs[idxs]) == 0) + def test_set_data_fails(self): with mock.patch("bilby.core.utils.create_frequency_series") as m: m.return_value = [1, 2, 3] @@ -316,5 +359,67 @@ class TestInterferometerStrainDataEquals(unittest.TestCase): self.assertNotEqual(self.ifosd_1, self.ifosd_2) +class TestNotch(unittest.TestCase): + def setUp(self): + self.minimum_frequency = 20 + self.maximum_frequency = 1024 + + def test_init(self): + notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency) + self.assertEqual(notch.minimum_frequency, self.minimum_frequency) + self.assertEqual(notch.maximum_frequency, self.maximum_frequency) + + def test_init_fail(self): + # Infinite frequency + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.Notch(self.minimum_frequency, np.inf) + + # Negative frequency + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.Notch(-10, 1024) + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.Notch(10, -1024) + + # Ordering + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.Notch(30, 20) + + def test_idxs(self): + notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency) + freqs = np.linspace(0, 2048, 100) + idxs = notch.get_idxs(freqs) + self.assertEqual(len(idxs), len(freqs)) + freqs_masked = freqs[idxs] + self.assertTrue(all(freqs_masked > notch.minimum_frequency)) + self.assertTrue(all(freqs_masked < notch.maximum_frequency)) + + +class TestNotchList(unittest.TestCase): + + def test_init_single(self): + notch_list_of_tuples = [(32, 34)] + notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples) + self.assertEqual(len(notch_list), len(notch_list_of_tuples)) + for notch, notch_tuple in zip(notch_list, notch_list_of_tuples): + self.assertEqual(notch.minimum_frequency, notch_tuple[0]) + self.assertEqual(notch.maximum_frequency, notch_tuple[1]) + + def test_init_multiple(self): + notch_list_of_tuples = [(32, 34), (56, 59)] + notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples) + self.assertEqual(len(notch_list), len(notch_list_of_tuples)) + for notch, notch_tuple in zip(notch_list, notch_list_of_tuples): + self.assertEqual(notch.minimum_frequency, notch_tuple[0]) + self.assertEqual(notch.maximum_frequency, notch_tuple[1]) + + def test_init_fail(self): + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.NotchList([20, 30]) + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.NotchList([(30, 20), (20)]) + with self.assertRaises(ValueError): + bilby.gw.detector.strain_data.NotchList([(30, 20, 20)]) + + if __name__ == "__main__": unittest.main() -- GitLab