Skip to content
Snippets Groups Projects
Commit 88ed4d97 authored by Gregory Ashton's avatar Gregory Ashton Committed by Moritz Huebner
Browse files

Adds functionality to notch data

parent d57b7ebb
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,7 @@ class InterferometerStrainData(object):
time_array = PropertyAccessor('_times_and_frequencies', 'time_array')
def __init__(self, minimum_frequency=0, maximum_frequency=np.inf,
roll_off=0.2):
roll_off=0.2, notch_list=None):
""" Initiate an InterferometerStrainData object
The initialised object contains no data, this should be added using one
......@@ -46,11 +46,14 @@ class InterferometerStrainData(object):
roll_off: float
The roll-off (in seconds) used in the Tukey window, default=0.2s.
This corresponds to alpha * duration / 2 for scipy tukey window.
notch_list: bilby.gw.detector.strain_data.NotchList
A list of notches
"""
self.minimum_frequency = minimum_frequency
self.maximum_frequency = maximum_frequency
self.notch_list = notch_list
self.roll_off = roll_off
self.window_factor = 1
......@@ -122,18 +125,46 @@ class InterferometerStrainData(object):
self._maximum_frequency = maximum_frequency
self._frequency_mask_updated = False
@property
def notch_list(self):
return self._notch_list
@notch_list.setter
def notch_list(self, notch_list):
""" Set the notch_list
Parameters
----------
notch_list: list, bilby.gw.detector.strain_data.NotchList
A list of length-2 tuples of the (max, min) frequency for the
notches or a pre-made bilby NotchList.
"""
if notch_list is None:
self._notch_list = NotchList(None)
elif isinstance(notch_list, list):
self._notch_list = NotchList(notch_list)
elif isinstance(notch_list, NotchList):
self._notch_list = notch_list
else:
raise ValueError("notch_list {} not understood".format(notch_list))
self._frequency_mask_updated = False
@property
def frequency_mask(self):
"""Masking array for limiting the frequency band.
""" Masking array for limiting the frequency band.
Returns
-------
array_like: An array of boolean values
mask: np.ndarray
An array of boolean values
"""
if not self._frequency_mask_updated:
frequency_array = self._times_and_frequencies.frequency_array
mask = ((frequency_array >= self.minimum_frequency) &
(frequency_array <= self.maximum_frequency))
for notch in self.notch_list:
mask[notch.get_idxs(frequency_array)] = False
self._frequency_mask = mask
self._frequency_mask_updated = True
return self._frequency_mask
......@@ -683,3 +714,104 @@ class InterferometerStrainData(object):
strain = strain.resample(sampling_frequency)
self.set_from_gwpy_timeseries(strain)
class Notch(object):
def __init__(self, minimum_frequency, maximum_frequency):
""" A notch object storing the maximum and minimum frequency of the notch
Parameters
----------
minimum_frequency, maximum_frequency: float
The minimum and maximum frequency of the notch
"""
if 0 < minimum_frequency < maximum_frequency < np.inf:
self.minimum_frequency = minimum_frequency
self.maximum_frequency = maximum_frequency
else:
msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid"
.format(minimum_frequency, maximum_frequency))
raise ValueError(msg)
def get_idxs(self, frequency_array):
""" Get a boolean mask for the frequencies in frequency_array in the notch
Parameters
----------
frequency_array: np.ndarray
An array of frequencies
Returns
-------
idxs: np.ndarray
An array of booleans which are True for frequencies in the notch
"""
lower = (frequency_array > self.minimum_frequency)
upper = (frequency_array < self.maximum_frequency)
return lower & upper
def check_frequency(self, freq):
""" Check if freq is inside the notch
Parameters
----------
freq: float
The frequency to check
Returns
-------
True/False:
If freq inside the notch, return True, else False
"""
if self.minimum_frequency < freq < self.maximum_frequency:
return True
else:
return False
class NotchList(list):
def __init__(self, notch_list):
""" A list of notches
Parameters
----------
notch_list: list
A list of length-2 tuples of the (max, min) frequency for the
notches.
Raises
------
ValueError
If the list is malformed.
"""
if notch_list is not None:
for notch in notch_list:
if isinstance(notch, tuple) and len(notch) == 2:
self.append(Notch(*notch))
else:
msg = "notch_list {} is malformed".format(notch_list)
raise ValueError(msg)
def check_frequency(self, freq):
""" Check if freq is inside the notch list
Parameters
----------
freq: float
The frequency to check
Returns
-------
True/False:
If freq inside any of the notches, return True, else False
"""
for notch in self:
if notch.check_frequency(freq):
return True
return False
......@@ -35,6 +35,49 @@ class TestInterferometerStrainData(unittest.TestCase):
np.array_equal(self.ifosd.frequency_mask, [False, True, False])
)
def test_frequency_mask_2(self):
strain_data = bilby.gw.detector.InterferometerStrainData(
minimum_frequency=20, maximum_frequency=512)
strain_data.set_from_time_domain_strain(
time_domain_strain=np.random.normal(0, 1, 4096),
time_array=np.arange(0, 4, 4 / 4096)
)
# Test from init
freqs = strain_data.frequency_array[strain_data.frequency_mask]
self.assertTrue(all(freqs >= 20))
self.assertTrue(all(freqs <= 512))
# Test from update
strain_data.minimum_frequency = 30
strain_data.maximum_frequency = 256
freqs = strain_data.frequency_array[strain_data.frequency_mask]
self.assertTrue(all(freqs >= 30))
self.assertTrue(all(freqs <= 256))
def test_notches_frequency_mask(self):
strain_data = bilby.gw.detector.InterferometerStrainData(
minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)])
strain_data.set_from_time_domain_strain(
time_domain_strain=np.random.normal(0, 1, 4096),
time_array=np.arange(0, 4, 4 / 4096)
)
# Test from init
freqs = strain_data.frequency_array[strain_data.frequency_mask]
idxs = (freqs > 100) * (freqs < 101)
self.assertTrue(len(freqs[idxs]) == 0)
# Test from setting
idxs = (freqs > 200) * (freqs < 201)
self.assertTrue(len(freqs[idxs]) > 0)
strain_data.notch_list = [(100, 101), (200, 201)]
freqs = strain_data.frequency_array[strain_data.frequency_mask]
idxs = (freqs > 200) * (freqs < 201)
self.assertTrue(len(freqs[idxs]) == 0)
idxs = (freqs > 100) * (freqs < 101)
self.assertTrue(len(freqs[idxs]) == 0)
def test_set_data_fails(self):
with mock.patch("bilby.core.utils.create_frequency_series") as m:
m.return_value = [1, 2, 3]
......@@ -316,5 +359,67 @@ class TestInterferometerStrainDataEquals(unittest.TestCase):
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
class TestNotch(unittest.TestCase):
def setUp(self):
self.minimum_frequency = 20
self.maximum_frequency = 1024
def test_init(self):
notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
self.assertEqual(notch.minimum_frequency, self.minimum_frequency)
self.assertEqual(notch.maximum_frequency, self.maximum_frequency)
def test_init_fail(self):
# Infinite frequency
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(self.minimum_frequency, np.inf)
# Negative frequency
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(-10, 1024)
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(10, -1024)
# Ordering
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(30, 20)
def test_idxs(self):
notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
freqs = np.linspace(0, 2048, 100)
idxs = notch.get_idxs(freqs)
self.assertEqual(len(idxs), len(freqs))
freqs_masked = freqs[idxs]
self.assertTrue(all(freqs_masked > notch.minimum_frequency))
self.assertTrue(all(freqs_masked < notch.maximum_frequency))
class TestNotchList(unittest.TestCase):
def test_init_single(self):
notch_list_of_tuples = [(32, 34)]
notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
self.assertEqual(len(notch_list), len(notch_list_of_tuples))
for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
self.assertEqual(notch.minimum_frequency, notch_tuple[0])
self.assertEqual(notch.maximum_frequency, notch_tuple[1])
def test_init_multiple(self):
notch_list_of_tuples = [(32, 34), (56, 59)]
notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
self.assertEqual(len(notch_list), len(notch_list_of_tuples))
for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
self.assertEqual(notch.minimum_frequency, notch_tuple[0])
self.assertEqual(notch.maximum_frequency, notch_tuple[1])
def test_init_fail(self):
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([20, 30])
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([(30, 20), (20)])
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([(30, 20, 20)])
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment