From 88ed4d97215667843b9f83c130d9ae24640a4027 Mon Sep 17 00:00:00 2001
From: Gregory Ashton <gregory.ashton@ligo.org>
Date: Wed, 16 Dec 2020 04:36:56 -0600
Subject: [PATCH] Adds functionality to notch data

---
 bilby/gw/detector/strain_data.py     | 138 ++++++++++++++++++++++++++-
 test/gw/detector/strain_data_test.py | 105 ++++++++++++++++++++
 2 files changed, 240 insertions(+), 3 deletions(-)

diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py
index 6ba6db039..da0a39782 100644
--- a/bilby/gw/detector/strain_data.py
+++ b/bilby/gw/detector/strain_data.py
@@ -31,7 +31,7 @@ class InterferometerStrainData(object):
     time_array = PropertyAccessor('_times_and_frequencies', 'time_array')
 
     def __init__(self, minimum_frequency=0, maximum_frequency=np.inf,
-                 roll_off=0.2):
+                 roll_off=0.2, notch_list=None):
         """ Initiate an InterferometerStrainData object
 
         The initialised object contains no data, this should be added using one
@@ -46,11 +46,14 @@ class InterferometerStrainData(object):
         roll_off: float
             The roll-off (in seconds) used in the Tukey window, default=0.2s.
             This corresponds to alpha * duration / 2 for scipy tukey window.
+        notch_list: bilby.gw.detector.strain_data.NotchList
+            A list of notches
 
         """
 
         self.minimum_frequency = minimum_frequency
         self.maximum_frequency = maximum_frequency
+        self.notch_list = notch_list
         self.roll_off = roll_off
         self.window_factor = 1
 
@@ -122,18 +125,46 @@ class InterferometerStrainData(object):
         self._maximum_frequency = maximum_frequency
         self._frequency_mask_updated = False
 
+    @property
+    def notch_list(self):
+        return self._notch_list
+
+    @notch_list.setter
+    def notch_list(self, notch_list):
+        """ Set the notch_list
+
+        Parameters
+        ----------
+        notch_list: list, bilby.gw.detector.strain_data.NotchList
+            A list of length-2 tuples of the (max, min) frequency for the
+            notches or a pre-made bilby NotchList.
+
+        """
+        if notch_list is None:
+            self._notch_list = NotchList(None)
+        elif isinstance(notch_list, list):
+            self._notch_list = NotchList(notch_list)
+        elif isinstance(notch_list, NotchList):
+            self._notch_list = notch_list
+        else:
+            raise ValueError("notch_list {} not understood".format(notch_list))
+        self._frequency_mask_updated = False
+
     @property
     def frequency_mask(self):
-        """Masking array for limiting the frequency band.
+        """ Masking array for limiting the frequency band.
 
         Returns
         -------
-        array_like: An array of boolean values
+        mask: np.ndarray
+            An array of boolean values
         """
         if not self._frequency_mask_updated:
             frequency_array = self._times_and_frequencies.frequency_array
             mask = ((frequency_array >= self.minimum_frequency) &
                     (frequency_array <= self.maximum_frequency))
+            for notch in self.notch_list:
+                mask[notch.get_idxs(frequency_array)] = False
             self._frequency_mask = mask
             self._frequency_mask_updated = True
         return self._frequency_mask
@@ -683,3 +714,104 @@ class InterferometerStrainData(object):
         strain = strain.resample(sampling_frequency)
 
         self.set_from_gwpy_timeseries(strain)
+
+
+class Notch(object):
+    def __init__(self, minimum_frequency, maximum_frequency):
+        """ A notch object storing the maximum and minimum frequency of the notch
+
+        Parameters
+        ----------
+        minimum_frequency, maximum_frequency: float
+            The minimum and maximum frequency of the notch
+
+        """
+
+        if 0 < minimum_frequency < maximum_frequency < np.inf:
+            self.minimum_frequency = minimum_frequency
+            self.maximum_frequency = maximum_frequency
+        else:
+            msg = ("Your notch minimum_frequency {} and maximum_frequency {} are invalid"
+                   .format(minimum_frequency, maximum_frequency))
+            raise ValueError(msg)
+
+    def get_idxs(self, frequency_array):
+        """ Get a boolean mask for the frequencies in frequency_array in the notch
+
+        Parameters
+        ----------
+        frequency_array: np.ndarray
+            An array of frequencies
+
+        Returns
+        -------
+        idxs: np.ndarray
+            An array of booleans which are True for frequencies in the notch
+
+        """
+        lower = (frequency_array > self.minimum_frequency)
+        upper = (frequency_array < self.maximum_frequency)
+        return lower & upper
+
+    def check_frequency(self, freq):
+        """ Check if freq is inside the notch
+
+        Parameters
+        ----------
+        freq: float
+            The frequency to check
+
+        Returns
+        -------
+        True/False:
+            If freq inside the notch, return True, else False
+        """
+
+        if self.minimum_frequency < freq < self.maximum_frequency:
+            return True
+        else:
+            return False
+
+
+class NotchList(list):
+    def __init__(self, notch_list):
+        """ A list of notches
+
+        Parameters
+        ----------
+        notch_list: list
+            A list of length-2 tuples of the (max, min) frequency for the
+            notches.
+
+        Raises
+        ------
+        ValueError
+            If the list is malformed.
+        """
+
+        if notch_list is not None:
+            for notch in notch_list:
+                if isinstance(notch, tuple) and len(notch) == 2:
+                    self.append(Notch(*notch))
+                else:
+                    msg = "notch_list {} is malformed".format(notch_list)
+                    raise ValueError(msg)
+
+    def check_frequency(self, freq):
+        """ Check if freq is inside the notch list
+
+        Parameters
+        ----------
+        freq: float
+            The frequency to check
+
+        Returns
+        -------
+        True/False:
+            If freq inside any of the notches, return True, else False
+        """
+
+        for notch in self:
+            if notch.check_frequency(freq):
+                return True
+        return False
diff --git a/test/gw/detector/strain_data_test.py b/test/gw/detector/strain_data_test.py
index a11e347df..9cba74104 100644
--- a/test/gw/detector/strain_data_test.py
+++ b/test/gw/detector/strain_data_test.py
@@ -35,6 +35,49 @@ class TestInterferometerStrainData(unittest.TestCase):
                 np.array_equal(self.ifosd.frequency_mask, [False, True, False])
             )
 
+    def test_frequency_mask_2(self):
+        strain_data = bilby.gw.detector.InterferometerStrainData(
+            minimum_frequency=20, maximum_frequency=512)
+        strain_data.set_from_time_domain_strain(
+            time_domain_strain=np.random.normal(0, 1, 4096),
+            time_array=np.arange(0, 4, 4 / 4096)
+        )
+
+        # Test from init
+        freqs = strain_data.frequency_array[strain_data.frequency_mask]
+        self.assertTrue(all(freqs >= 20))
+        self.assertTrue(all(freqs <= 512))
+
+        # Test from update
+        strain_data.minimum_frequency = 30
+        strain_data.maximum_frequency = 256
+        freqs = strain_data.frequency_array[strain_data.frequency_mask]
+        self.assertTrue(all(freqs >= 30))
+        self.assertTrue(all(freqs <= 256))
+
+    def test_notches_frequency_mask(self):
+        strain_data = bilby.gw.detector.InterferometerStrainData(
+            minimum_frequency=20, maximum_frequency=512, notch_list=[(100, 101)])
+        strain_data.set_from_time_domain_strain(
+            time_domain_strain=np.random.normal(0, 1, 4096),
+            time_array=np.arange(0, 4, 4 / 4096)
+        )
+
+        # Test from init
+        freqs = strain_data.frequency_array[strain_data.frequency_mask]
+        idxs = (freqs > 100) * (freqs < 101)
+        self.assertTrue(len(freqs[idxs]) == 0)
+
+        # Test from setting
+        idxs = (freqs > 200) * (freqs < 201)
+        self.assertTrue(len(freqs[idxs]) > 0)
+        strain_data.notch_list = [(100, 101), (200, 201)]
+        freqs = strain_data.frequency_array[strain_data.frequency_mask]
+        idxs = (freqs > 200) * (freqs < 201)
+        self.assertTrue(len(freqs[idxs]) == 0)
+        idxs = (freqs > 100) * (freqs < 101)
+        self.assertTrue(len(freqs[idxs]) == 0)
+
     def test_set_data_fails(self):
         with mock.patch("bilby.core.utils.create_frequency_series") as m:
             m.return_value = [1, 2, 3]
@@ -316,5 +359,67 @@ class TestInterferometerStrainDataEquals(unittest.TestCase):
         self.assertNotEqual(self.ifosd_1, self.ifosd_2)
 
 
+class TestNotch(unittest.TestCase):
+    def setUp(self):
+        self.minimum_frequency = 20
+        self.maximum_frequency = 1024
+
+    def test_init(self):
+        notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
+        self.assertEqual(notch.minimum_frequency, self.minimum_frequency)
+        self.assertEqual(notch.maximum_frequency, self.maximum_frequency)
+
+    def test_init_fail(self):
+        # Infinite frequency
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.Notch(self.minimum_frequency, np.inf)
+
+        # Negative frequency
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.Notch(-10, 1024)
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.Notch(10, -1024)
+
+        # Ordering
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.Notch(30, 20)
+
+    def test_idxs(self):
+        notch = bilby.gw.detector.strain_data.Notch(self.minimum_frequency, self.maximum_frequency)
+        freqs = np.linspace(0, 2048, 100)
+        idxs = notch.get_idxs(freqs)
+        self.assertEqual(len(idxs), len(freqs))
+        freqs_masked = freqs[idxs]
+        self.assertTrue(all(freqs_masked > notch.minimum_frequency))
+        self.assertTrue(all(freqs_masked < notch.maximum_frequency))
+
+
+class TestNotchList(unittest.TestCase):
+
+    def test_init_single(self):
+        notch_list_of_tuples = [(32, 34)]
+        notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
+        self.assertEqual(len(notch_list), len(notch_list_of_tuples))
+        for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
+            self.assertEqual(notch.minimum_frequency, notch_tuple[0])
+            self.assertEqual(notch.maximum_frequency, notch_tuple[1])
+
+    def test_init_multiple(self):
+        notch_list_of_tuples = [(32, 34), (56, 59)]
+        notch_list = bilby.gw.detector.strain_data.NotchList(notch_list_of_tuples)
+        self.assertEqual(len(notch_list), len(notch_list_of_tuples))
+        for notch, notch_tuple in zip(notch_list, notch_list_of_tuples):
+            self.assertEqual(notch.minimum_frequency, notch_tuple[0])
+            self.assertEqual(notch.maximum_frequency, notch_tuple[1])
+
+    def test_init_fail(self):
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.NotchList([20, 30])
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.NotchList([(30, 20), (20)])
+        with self.assertRaises(ValueError):
+            bilby.gw.detector.strain_data.NotchList([(30, 20, 20)])
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab