Commit 9201fde4 authored by Moritz Huebner's avatar Moritz Huebner

Merge branch 'add-notch-functionality' into 'master'

Adds functionality to notch data

See merge request lscsoft/bilby!898
parents 9092f0f2 88ed4d97
......@@ -31,7 +31,7 @@ class InterferometerStrainData(object):
time_array = PropertyAccessor('_times_and_frequencies', 'time_array')
def __init__(self, minimum_frequency=0, maximum_frequency=np.inf,
roll_off=0.2):
roll_off=0.2, notch_list=None):
""" Initiate an InterferometerStrainData object
The initialised object contains no data, this should be added using one
......@@ -46,11 +46,14 @@ class InterferometerStrainData(object):
roll_off: float
The roll-off (in seconds) used in the Tukey window, default=0.2s.
This corresponds to alpha * duration / 2 for scipy tukey window.
notch_list: bilby.gw.detector.strain_data.NotchList
A list of notches
"""
self.minimum_frequency = minimum_frequency
self.maximum_frequency = maximum_frequency
self.notch_list = notch_list
self.roll_off = roll_off
self.window_factor = 1
......@@ -122,18 +125,46 @@ class InterferometerStrainData(object):
self._maximum_frequency = maximum_frequency
self._frequency_mask_updated = False
@property
def notch_list(self):
return self._notch_list
@notch_list.setter
def notch_list(self, notch_list):
""" Set the notch_list
Parameters
----------
notch_list: list, bilby.gw.detector.strain_data.NotchList
A list of length-2 tuples of the (max, min) frequency for the
notches or a pre-made bilby NotchList.
"""
if notch_list is None:
self._notch_list = NotchList(None)
elif isinstance(notch_list, list):
self._notch_list = NotchList(notch_list)
elif isinstance(notch_list, NotchList):
self._notch_list = notch_list
else:
raise ValueError("notch_list {} not understood".format(notch_list))
self._frequency_mask_updated = False
@property
def frequency_mask(self):
"""Masking array for limiting the frequency band.
""" Masking array for limiting the frequency band.
Returns
-------
array_like: An array of boolean values
mask: np.ndarray
An array of boolean values
"""
if not self._frequency_mask_updated:
frequency_array = self._times_and_frequencies.frequency_array
mask = ((frequency_array >= self.minimum_frequency) &
(frequency_array <= self.maximum_frequency))
for notch in self.notch_list:
mask[notch.get_idxs(frequency_array)] = False
self._frequency_mask = mask
self._frequency_mask_updated = True
return self._frequency_mask
......@@ -683,3 +714,104 @@ class InterferometerStrainData(object):
strain = strain.resample(sampling_frequency)
self.set_from_gwpy_timeseries(strain)
class Notch(object):
def __init__(self, minimum_frequency, maximum_frequency):
""" A notch object storing the maximum and minimum frequency of the notch
Parameters
----------
minimum_frequency, maximum_frequency: float
The minimum and maximum frequency of the notch
"""
if 0 < minimum_frequency < maximum_frequency < np.inf:
self.minimum_frequency = minimum_frequency
self.maximum_frequency = maximum_frequency
else:
msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid"
.format(minimum_frequency, maximum_frequency))
raise ValueError(msg)
def get_idxs(self, frequency_array):
""" Get a boolean mask for the frequencies in frequency_array in the notch
Parameters
----------
frequency_array: np.ndarray
An array of frequencies
Returns
-------
idxs: np.ndarray
An array of booleans which are True for frequencies in the notch
"""
lower = (frequency_array > self.minimum_frequency)
upper = (frequency_array < self.maximum_frequency)
return lower & upper
def check_frequency(self, freq):
""" Check if freq is inside the notch
Parameters
----------
freq: float
The frequency to check
Returns
-------
True/False:
If freq inside the notch, return True, else False
"""
if self.minimum_frequency < freq < self.maximum_frequency:
return True
else:
return False
class NotchList(list):
def __init__(self, notch_list):
""" A list of notches
Parameters
----------
notch_list: list
A list of length-2 tuples of the (max, min) frequency for the
notches.
Raises
------
ValueError
If the list is malformed.
"""
if notch_list is not None:
for notch in notch_list:
if isinstance(notch, tuple) and len(notch) == 2:
self.append(Notch(*notch))
else:
msg = "notch_list {} is malformed".format(notch_list)
raise ValueError(msg)
def check_frequency(self, freq):
""" Check if freq is inside the notch list
Parameters
----------
freq: float
The frequency to check
Returns
-------
True/False:
If freq inside any of the notches, return True, else False
"""
for notch in self:
if notch.check_frequency(freq):
return True
return False
......@@ -35,6 +35,49 @@ class TestInterferometerStrainData(unittest.TestCase):
np.array_equal(self.ifosd.frequency_mask, [False, True, False])
)
def test_frequency_mask_2(self):
strain_data = bilby.gw.detector.InterferometerStrainData(
minimum_frequency=20, maximum_frequency=512)
strain_data.set_from_time_domain_strain(
time_domain_strain=np.random.normal(0, 1, 4096),
time_array=np.arange(0, 4, 4 / 4096)
)
# Test from init
freqs = strain_data.frequency_array[strain_data.frequency_mask]
self.assertTrue(all(freqs >= 20))
self.assertTrue(all(freqs <= 512))
# Test from update
strain_data.minimum_frequency = 30
strain_data.maximum_frequency = 256
freqs = strain_data.frequency_array[strain_data.frequency_mask]
self.assertTrue(all(freqs >= 30))
self.assertTrue(all(freqs <= 256))
def test_notches_frequency_mask(self):
strain_data = bilby.gw.detector.InterferometerStrainData(
minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)])
strain_data.set_from_time_domain_strain(
time_domain_strain=np.random.normal(0, 1, 4096),
time_array=np.arange(0, 4, 4 / 4096)
)
# Test from init
freqs = strain_data.frequency_array[strain_data.frequency_mask]
idxs = (freqs > 100) * (freqs < 101)
self.assertTrue(len(freqs[idxs]) == 0)
# Test from setting
idxs = (freqs > 200) * (freqs < 201)
self.assertTrue(len(freqs[idxs]) > 0)
strain_data.notch_list = [(100, 101), (200, 201)]
freqs = strain_data.frequency_array[strain_data.frequency_mask]
idxs = (freqs > 200) * (freqs < 201)
self.assertTrue(len(freqs[idxs]) == 0)
idxs = (freqs > 100) * (freqs < 101)
self.assertTrue(len(freqs[idxs]) == 0)
def test_set_data_fails(self):
with mock.patch("bilby.core.utils.create_frequency_series") as m:
m.return_value = [1, 2, 3]
......@@ -316,5 +359,67 @@ class TestInterferometerStrainDataEquals(unittest.TestCase):
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
class TestNotch(unittest.TestCase):
def setUp(self):
self.minimum_frequency = 20
self.maximum_frequency = 1024
def test_init(self):
notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
self.assertEqual(notch.minimum_frequency, self.minimum_frequency)
self.assertEqual(notch.maximum_frequency, self.maximum_frequency)
def test_init_fail(self):
# Infinite frequency
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(self.minimum_frequency, np.inf)
# Negative frequency
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(-10, 1024)
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(10, -1024)
# Ordering
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.Notch(30, 20)
def test_idxs(self):
notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
freqs = np.linspace(0, 2048, 100)
idxs = notch.get_idxs(freqs)
self.assertEqual(len(idxs), len(freqs))
freqs_masked = freqs[idxs]
self.assertTrue(all(freqs_masked > notch.minimum_frequency))
self.assertTrue(all(freqs_masked < notch.maximum_frequency))
class TestNotchList(unittest.TestCase):
def test_init_single(self):
notch_list_of_tuples = [(32, 34)]
notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
self.assertEqual(len(notch_list), len(notch_list_of_tuples))
for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
self.assertEqual(notch.minimum_frequency, notch_tuple[0])
self.assertEqual(notch.maximum_frequency, notch_tuple[1])
def test_init_multiple(self):
notch_list_of_tuples = [(32, 34), (56, 59)]
notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
self.assertEqual(len(notch_list), len(notch_list_of_tuples))
for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
self.assertEqual(notch.minimum_frequency, notch_tuple[0])
self.assertEqual(notch.maximum_frequency, notch_tuple[1])
def test_init_fail(self):
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([20, 30])
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([(30, 20), (20)])
with self.assertRaises(ValueError):
bilby.gw.detector.strain_data.NotchList([(30, 20, 20)])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment