From ae3e7fb4a3bcb708cd0bc0c44263cb1738fbb1ed Mon Sep 17 00:00:00 2001 From: Moritz <email@moritz-huebner.de> Date: Tue, 13 Nov 2018 15:40:08 +1100 Subject: [PATCH] Implemented `__eq__` method and tests for `InterferometerStrainData` --- bilby/gw/detector.py | 16 ++++++++ test/detector_test.py | 90 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/bilby/gw/detector.py b/bilby/gw/detector.py index d87ce2aaf..dc28bba21 100644 --- a/bilby/gw/detector.py +++ b/bilby/gw/detector.py @@ -240,6 +240,22 @@ class InterferometerStrainData(object): self._time_domain_strain = None self._time_array = None + def __eq__(self, other): + if self.minimum_frequency == other.minimum_frequency \ + and self.maximum_frequency == other.maximum_frequency \ + and self.roll_off == other.roll_off \ + and self.window_factor == other.window_factor \ + and self.sampling_frequency == other.sampling_frequency \ + and self.duration == other.duration \ + and self.start_time == other.start_time \ + and np.array_equal(self.time_array, other.time_array) \ + and np.array_equal(self.frequency_array, other.frequency_array) \ + and np.array_equal(self.frequency_domain_strain, other.frequency_domain_strain) \ + and np.array_equal(self.time_domain_strain, other.time_domain_strain): + return True + return False + + @property def frequency_array(self): """ Frequencies of the data in Hz """ diff --git a/test/detector_test.py b/test/detector_test.py index 59efc53c8..add1d4b1b 100644 --- a/test/detector_test.py +++ b/test/detector_test.py @@ -7,7 +7,6 @@ from mock import MagicMock from mock import patch import numpy as np import scipy.signal.windows -import gwpy import os import logging @@ -613,6 +612,95 @@ class TestInterferometerStrainData(unittest.TestCase): self.ifosd.frequency_domain_strain = np.array([1]) +class TestInterferometerStrainDataEquals(unittest.TestCase): + + def setUp(self): + self.minimum_frequency = 10 + self.maximum_frequency = 20 + self.roll_off = 0.2 + self.sampling_frequency = 100 + self.duration = 2 + self.frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency, + duration=self.duration) + self.strain = self.frequency_array + self.ifosd_1 = bilby.gw.detector.InterferometerStrainData(minimum_frequency=self.minimum_frequency, + maximum_frequency=self.maximum_frequency, + roll_off=self.roll_off) + self.ifosd_2 = bilby.gw.detector.InterferometerStrainData(minimum_frequency=self.minimum_frequency, + maximum_frequency=self.maximum_frequency, + roll_off=self.roll_off) + self.ifosd_1.set_from_frequency_domain_strain(frequency_domain_strain=self.strain, + frequency_array=self.frequency_array) + self.ifosd_2.set_from_frequency_domain_strain(frequency_domain_strain=self.strain, + frequency_array=self.frequency_array) + + def tearDown(self): + del self.minimum_frequency + del self.maximum_frequency + del self.roll_off + del self.sampling_frequency + del self.duration + del self.frequency_array + del self.strain + del self.ifosd_1 + del self.ifosd_2 + + def test_eq_true(self): + self.assertEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_minimum_frequency(self): + self.ifosd_1.minimum_frequency -= 1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_maximum_frequency(self): + self.ifosd_1.maximum_frequency -= 1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_roll_off(self): + self.ifosd_1.roll_off -= 0.1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_window_factor(self): + self.ifosd_1.roll_off -= 0.1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_sampling_frequency(self): + self.ifosd_1.sampling_frequency -= 0.1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_sampling_duration(self): + self.ifosd_1.duration -= 0.1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_start_time(self): + self.ifosd_1.start_time -= 0.1 + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_frequency_array(self): + new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, + duration=self.duration*2) + self.ifosd_1._frequency_array = new_frequency_array + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_frequency_domain_strain(self): + new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, + duration=self.duration*2) + self.ifosd_1._frequency_domain_strain = new_strain + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_time_array(self): + new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2, + duration=self.duration*2) + self.ifosd_1._time_array = new_time_array + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + def test_eq_different_time_domain_strain(self): + new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2, + duration=self.duration*2) + self.ifosd_1._time_domain_strain= new_strain + self.assertNotEqual(self.ifosd_1, self.ifosd_2) + + class TestInterferometerList(unittest.TestCase): def setUp(self): -- GitLab