diff --git a/bilby/gw/calibration.py b/bilby/gw/calibration.py index 1ed90538a15d683d698ed6358197d16ffedfd067..3a5ea077d0fdc3cf9a90e6415e82d15f8731d5df 100644 --- a/bilby/gw/calibration.py +++ b/bilby/gw/calibration.py @@ -49,6 +49,8 @@ class Recalibrate(object): self.params.update({key[len(self.prefix):]: params[key] for key in params if self.prefix in key}) + def __eq__(self, other): + return self.__dict__ == other.__dict__ class CubicSpline(Recalibrate): diff --git a/bilby/gw/detector.py b/bilby/gw/detector.py index dc28bba217e380669e675a1a82a0baec1b69e587..c7586cb86b3c3bc99039af52263cfdc755a0543f 100644 --- a/bilby/gw/detector.py +++ b/bilby/gw/detector.py @@ -107,7 +107,7 @@ class InterferometerList(list): """ if injection_polarizations is None: if waveform_generator is not None: - injection_polarizations =\ + injection_polarizations = \ waveform_generator.frequency_domain_strain(parameters) else: raise ValueError( @@ -255,7 +255,6 @@ class InterferometerStrainData(object): return True return False - @property def frequency_array(self): """ Frequencies of the data in Hz """ @@ -444,7 +443,7 @@ class InterferometerStrainData(object): logger.info( "Low pass filter frequency of {}Hz requested, this is equal" " or greater than the Nyquist frequency so no filter applied" - .format(filter_freq)) + .format(filter_freq)) return logger.debug("Applying low pass filter with filter frequency {}".format(filter_freq)) @@ -825,6 +824,22 @@ class Interferometer(object): minimum_frequency=minimum_frequency, maximum_frequency=maximum_frequency) + def __eq__(self, other): + if self.name == other.name and \ + self.length == other.length and \ + self.latitude == other.latitude and \ + self.longitude == other.longitude and \ + self.elevation == other.elevation and \ + self.xarm_azimuth == other.xarm_azimuth and \ + self.xarm_tilt == other.xarm_tilt and \ + self.yarm_azimuth == other.yarm_azimuth and \ + self.yarm_tilt == other.yarm_tilt and \ + self.power_spectral_density.__eq__(other.power_spectral_density) and \ + self.calibration_model == other.calibration_model and \ + self.strain_data == other.strain_data: + return True + return False + def __repr__(self): return self.__class__.__name__ + '(name=\'{}\', power_spectral_density={}, minimum_frequency={}, ' \ 'maximum_frequency={}, length={}, latitude={}, longitude={}, elevation={}, ' \ @@ -1259,7 +1274,7 @@ class Interferometer(object): if injection_polarizations is None: if waveform_generator is not None: - injection_polarizations =\ + injection_polarizations = \ waveform_generator.frequency_domain_strain(parameters) else: raise ValueError( @@ -1275,7 +1290,7 @@ class Interferometer(object): if not self.strain_data.time_within_data(parameters['geocent_time']): logger.warning( 'Injecting signal outside segment, start_time={}, merger time={}.' - .format(self.strain_data.start_time, parameters['geocent_time'])) + .format(self.strain_data.start_time, parameters['geocent_time'])) signal_ifo = self.get_detector_response(injection_polarizations, parameters) if np.shape(self.frequency_domain_strain).__eq__(np.shape(signal_ifo)): @@ -1654,6 +1669,15 @@ class PowerSpectralDensity(object): self.psd_file = psd_file self.asd_file = asd_file + def __eq__(self, other): + if self.psd_file == other.psd_file \ + and self.asd_file == other.asd_file \ + and np.array_equal(self.frequency_array, other.frequency_array) \ + and np.array_equal(self.psd_array, other.psd_array) \ + and np.array_equal(self.asd_array, other.asd_array): + return True + return False + def __repr__(self): if self.asd_file is not None or self.psd_file is not None: return self.__class__.__name__ + '(psd_file=\'{}\', asd_file=\'{}\')' \ @@ -1775,12 +1799,12 @@ class PowerSpectralDensity(object): @property def asd_file(self): - return self.__asd_file + return self._asd_file @asd_file.setter def asd_file(self, asd_file): asd_file = self.__validate_file_name(file=asd_file) - self.__asd_file = asd_file + self._asd_file = asd_file if asd_file is not None: self.__import_amplitude_spectral_density() self.__check_file_was_asd_file() @@ -1794,12 +1818,12 @@ class PowerSpectralDensity(object): @property def psd_file(self): - return self.__psd_file + return self._psd_file @psd_file.setter def psd_file(self, psd_file): psd_file = self.__validate_file_name(file=psd_file) - self.__psd_file = psd_file + self._psd_file = psd_file if psd_file is not None: self.__import_power_spectral_density() self.__check_file_was_psd_file() @@ -2161,16 +2185,16 @@ def load_data_from_cache_file( frame_duration = float(frame_duration) if frame_name[:4] == 'file': frame_name = frame_name[16:] - if not data_set & (frame_start < segment_start) &\ - (segment_start < frame_start + frame_duration): + if not data_set & (frame_start < segment_start) & \ + (segment_start < frame_start + frame_duration): ifo.set_strain_data_from_frame_file( frame_name, 4096, segment_duration, start_time=segment_start, channel=channel_name, buffer_time=0) data_set = True - if not psd_set & (frame_start < psd_start) &\ - (psd_start + psd_duration < frame_start + frame_duration): - ifo.power_spectral_density =\ + if not psd_set & (frame_start < psd_start) & \ + (psd_start + psd_duration < frame_start + frame_duration): + ifo.power_spectral_density = \ PowerSpectralDensity.from_frame_file( frame_name, psd_start_time=psd_start, psd_duration=psd_duration, diff --git a/test/detector_test.py b/test/detector_test.py index add1d4b1b84e68e3509ec19328c0d30b14ebc6f5..9dd9d691eb765b97ad5272a731697d1859cf0dd8 100644 --- a/test/detector_test.py +++ b/test/detector_test.py @@ -346,6 +346,125 @@ class TestDetector(unittest.TestCase): self.assertEqual(expected, repr(self.ifo)) +class TestInterferometerEquals(unittest.TestCase): + + def setUp(self): + self.name = 'name' + self.power_spectral_density_1 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.power_spectral_density_2 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.minimum_frequency = 10 + self.maximum_frequency = 20 + self.length = 30 + self.latitude = 1 + self.longitude = 2 + self.elevation = 3 + self.xarm_azimuth = 4 + self.yarm_azimuth = 5 + self.xarm_tilt = 0. + self.yarm_tilt = 0. + # noinspection PyTypeChecker + self.duration = 1 + self.sampling_frequency = 200 + self.frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency, + duration=self.duration) + self.strain = self.frequency_array + self.ifo_1 = bilby.gw.detector.Interferometer(name=self.name, + power_spectral_density=self.power_spectral_density_1, + minimum_frequency=self.minimum_frequency, + maximum_frequency=self.maximum_frequency, length=self.length, + latitude=self.latitude, longitude=self.longitude, + elevation=self.elevation, + xarm_azimuth=self.xarm_azimuth, yarm_azimuth=self.yarm_azimuth, + xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt) + self.ifo_2 = bilby.gw.detector.Interferometer(name=self.name, + power_spectral_density=self.power_spectral_density_2, + minimum_frequency=self.minimum_frequency, + maximum_frequency=self.maximum_frequency, length=self.length, + latitude=self.latitude, longitude=self.longitude, + elevation=self.elevation, + xarm_azimuth=self.xarm_azimuth, yarm_azimuth=self.yarm_azimuth, + xarm_tilt=self.xarm_tilt, yarm_tilt=self.yarm_tilt) + self.ifo_1.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array, + frequency_domain_strain=self.strain) + self.ifo_2.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array, + frequency_domain_strain=self.strain) + + def tearDown(self): + del self.name + del self.power_spectral_density_1 + del self.power_spectral_density_2 + del self.minimum_frequency + del self.maximum_frequency + del self.length + del self.latitude + del self.longitude + del self.elevation + del self.xarm_azimuth + del self.yarm_azimuth + del self.xarm_tilt + del self.yarm_tilt + del self.ifo_1 + del self.ifo_2 + del self.sampling_frequency + del self.duration + del self.frequency_array + del self.strain + + def test_eq_true(self): + self.assertEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_psd(self): + self.ifo_1.power_spectral_density.psd_array[0] = 1234 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_minimum_frequency(self): + self.ifo_1.minimum_frequency -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_maximum_frequency(self): + self.ifo_1.minimum_frequency -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_length(self): + self.ifo_1.length -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_latitude(self): + self.ifo_1.latitude -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_longitude(self): + self.ifo_1.longitude -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_elevation(self): + self.ifo_1.elevation -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_xarm_azimuth(self): + self.ifo_1.xarm_azimuth -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_xarmtilt(self): + self.ifo_1.xarm_tilt -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_yarm_azimuth(self): + self.ifo_1.yarm_azimuth -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_yarm_tilt(self): + self.ifo_1.yarm_tilt -= 1 + self.assertNotEqual(self.ifo_1, self.ifo_2) + + def test_eq_false_different_ifo_strain_data(self): + self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, + duration=self.duration*2) + self.ifo_1.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array, + frequency_domain_strain=self.strain) + self.assertNotEqual(self.ifo_1, self.ifo_2) + + class TestInterferometerStrainData(unittest.TestCase): def setUp(self): @@ -1100,5 +1219,61 @@ class TestPowerSpectralDensityWithFiles(unittest.TestCase): self.assertEqual(expected, repr(psd)) +class TestPowerSpectralDensityEquals(unittest.TestCase): + + def setUp(self): + self.psd_from_file_1 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.psd_from_file_2 = bilby.gw.detector.PowerSpectralDensity.from_aligo() + self.frequency_array = np.linspace(1, 100) + self.psd_array = np.linspace(1, 100) + self.psd_from_array_1 = bilby.gw.detector.PowerSpectralDensity. \ + from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array) + self.psd_from_array_2 = bilby.gw.detector.PowerSpectralDensity. \ + from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array) + + def tearDown(self): + del self.psd_from_file_1 + del self.psd_from_file_2 + del self.frequency_array + del self.psd_array + del self.psd_from_array_1 + del self.psd_from_array_2 + + def test_eq_true_from_array(self): + self.assertEqual(self.psd_from_array_1, self.psd_from_array_2) + + def test_eq_true_from_file(self): + self.assertEqual(self.psd_from_file_1, self.psd_from_file_2) + + def test_eq_false_different_psd_file_name(self): + self.psd_from_file_1._psd_file = 'some_other_name' + self.assertNotEqual(self.psd_from_file_1, self.psd_from_file_2) + + def test_eq_false_different_asd_file_name(self): + self.psd_from_file_1._psd_file = None + self.psd_from_file_2._psd_file = None + self.psd_from_file_1._asd_file = 'some_name' + self.psd_from_file_2._asd_file = 'some_other_name' + self.assertNotEqual(self.psd_from_file_1, self.psd_from_file_2) + + def test_eq_false_different_frequency_array(self): + self.psd_from_file_1.frequency_array[0] = 0.5 + self.psd_from_array_1.frequency_array[0] = 0.5 + self.assertNotEqual(self.psd_from_file_1, self.psd_from_file_2) + self.assertNotEqual(self.psd_from_array_1, self.psd_from_array_2) + + def test_eq_false_different_psd(self): + self.psd_from_file_1.psd_array[0] = 0.53544321 + self.psd_from_array_1.psd_array[0] = 0.53544321 + self.assertNotEqual(self.psd_from_file_1, self.psd_from_file_2) + self.assertNotEqual(self.psd_from_array_1, self.psd_from_array_2) + + def test_eq_false_different_asd(self): + self.psd_from_file_1.asd_array[0] = 0.53544321 + self.psd_from_array_1.asd_array[0] = 0.53544321 + self.assertNotEqual(self.psd_from_file_1, self.psd_from_file_2) + self.assertNotEqual(self.psd_from_array_1, self.psd_from_array_2) + + if __name__ == '__main__': unittest.main()