diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 24afd74cb80c5c725e102bcd6cb02c0771a687e2..c466667112e887b22a8617fc6d7bdec71e146457 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -287,20 +287,23 @@ class Interferometer(object): return signal_ifo - def inject_signal(self, parameters=None, injection_polarizations=None, + def inject_signal(self, parameters, injection_polarizations=None, waveform_generator=None): - """ Inject a signal into noise + """ General signal injection method. + Provide the injection parameters and either the injection polarizations + or the waveform generator to inject a signal into the detector. + Defaults to the injection polarizations is both are given. Parameters ---------- parameters: dict Parameters of the injection. - injection_polarizations: dict + injection_polarizations: dict, optional Polarizations of waveform to inject, output of `waveform_generator.frequency_domain_strain()`. If `waveform_generator` is also given, the injection_polarizations will be calculated directly and this argument can be ignored. - waveform_generator: bilby.gw.waveform_generator.WaveformGenerator + waveform_generator: bilby.gw.waveform_generator.WaveformGenerator, optional A WaveformGenerator instance using the source model to inject. If `injection_polarizations` is given, this will be ignored. @@ -313,40 +316,71 @@ class Interferometer(object): Returns ------- injection_polarizations: dict + The injected polarizations. This is the same as the injection_polarizations parameters + if it was passed in. Otherwise it is the return value of waveform_generator.frequency_domain_strain(). + + """ + if injection_polarizations is None and waveform_generator is None: + raise ValueError( + "inject_signal needs one of waveform_generator or " + "injection_polarizations.") + elif injection_polarizations is not None: + self.inject_signal_from_waveform_polarizations(parameters=parameters, + injection_polarizations=injection_polarizations) + elif waveform_generator is not None: + injection_polarizations = self.inject_signal_from_waveform_generator(parameters=parameters, + waveform_generator=waveform_generator) + return injection_polarizations + + def inject_signal_from_waveform_generator(self, parameters, waveform_generator): + """ Inject a signal using a waveform generator and a set of parameters. + Alternative to `inject_signal` and `inject_signal_from_waveform_polarizations` + + Parameters + ---------- + parameters: dict + Parameters of the injection. + waveform_generator: bilby.gw.waveform_generator.WaveformGenerator + A WaveformGenerator instance using the source model to inject. + + Note + ------- + if your signal takes a substantial amount of time to generate, or + you experience buggy behaviour. It is preferable to use the + inject_signal_from_waveform_polarizations() method. + + Returns + ------- + injection_polarizations: dict + The internally generated injection parameters """ + injection_polarizations = \ + waveform_generator.frequency_domain_strain(parameters) + self.inject_signal_from_waveform_polarizations(parameters=parameters, + injection_polarizations=injection_polarizations) + return injection_polarizations + + def inject_signal_from_waveform_polarizations(self, parameters, injection_polarizations): + """ Inject a signal into the detector from a dict of waveform polarizations. + Alternative to `inject_signal` and `inject_signal_from_waveform_generator`. - if injection_polarizations is None: - if waveform_generator is not None: - injection_polarizations = \ - waveform_generator.frequency_domain_strain(parameters) - else: - raise ValueError( - "inject_signal needs one of waveform_generator or " - "injection_polarizations.") - - if injection_polarizations is None: - raise ValueError( - 'Trying to inject signal which is None. The most likely cause' - ' is that waveform_generator.frequency_domain_strain returned' - ' None. This can be caused if, e.g., mass_2 > mass_1.') + Parameters + ---------- + parameters: dict + Parameters of the injection. + injection_polarizations: dict + Polarizations of waveform to inject, output of + `waveform_generator.frequency_domain_strain()`. + """ if not self.strain_data.time_within_data(parameters['geocent_time']): logger.warning( 'Injecting signal outside segment, start_time={}, merger time={}.' .format(self.strain_data.start_time, parameters['geocent_time'])) signal_ifo = self.get_detector_response(injection_polarizations, parameters) - if np.shape(self.strain_data.frequency_domain_strain).__eq__(np.shape(signal_ifo)): - self.strain_data.frequency_domain_strain = \ - signal_ifo + self.strain_data.frequency_domain_strain - else: - logger.info('Injecting into zero noise.') - self.set_strain_data_from_frequency_domain_strain( - signal_ifo, - sampling_frequency=self.strain_data.sampling_frequency, - duration=self.strain_data.duration, - start_time=self.strain_data.start_time) + self.strain_data.frequency_domain_strain += signal_ifo self.meta_data['optimal_SNR'] = ( np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real) @@ -360,8 +394,6 @@ class Interferometer(object): for key in parameters: logger.info(' {} = {}'.format(key, parameters[key])) - return injection_polarizations - @property def amplitude_spectral_density_array(self): """ Returns the amplitude spectral density (ASD) given we know a power spectral denstiy (PSD) diff --git a/test/detector_test.py b/test/detector_test.py index f63ee5baa4cbdb9fe25998c016035f67a06da4cb..2f225c719ad4827bcea861eef7ea23769dbb0d85 100644 --- a/test/detector_test.py +++ b/test/detector_test.py @@ -1,6 +1,7 @@ from __future__ import absolute_import import bilby +import inspect import unittest import mock from mock import MagicMock @@ -261,6 +262,17 @@ class TestInterferometer(unittest.TestCase): self.ifo.strain_data.set_from_frequency_domain_strain( np.linspace(0, 4096, 4097), sampling_frequency=4096, duration=2) self.outdir = 'outdir' + + self.injection_polarizations = dict() + np.random.seed(42) + self.injection_polarizations['plus'] = np.random.random(4097) + self.injection_polarizations['cross'] = np.random.random(4097) + + self.waveform_generator = MagicMock() + self.wg_polarizations = dict(plus=np.random.random(4097), cross=np.random.random(4097)) + self.waveform_generator.frequency_domain_strain = lambda _: self.wg_polarizations + self.parameters = dict(ra=0., dec=0., geocent_time=0., psi=0.) + bilby.core.utils.check_directory_exists_and_if_not_mkdir(self.outdir) def tearDown(self): @@ -277,6 +289,10 @@ class TestInterferometer(unittest.TestCase): del self.xarm_tilt del self.yarm_tilt del self.ifo + del self.injection_polarizations + del self.wg_polarizations + del self.waveform_generator + del self.parameters rmtree(self.outdir) def test_name_setting(self): @@ -342,7 +358,86 @@ class TestInterferometer(unittest.TestCase): parameters=dict(ra=0, dec=0, geocent_time=0, psi=0)) self.assertTrue(np.array_equal(response, (plus + cross) * self.ifo.frequency_mask * np.exp(-0j))) - def test_inject_signal_no_waveform_polarizations(self): + def test_inject_signal_from_waveform_polarizations_correct_injection(self): + original_strain = self.ifo.strain_data.frequency_domain_strain + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) + + def test_inject_signal_from_waveform_polarizations_meta_data(self): + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + signal_ifo_expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + self.assertAlmostEqual(self.ifo.optimal_snr_squared(signal=signal_ifo_expected).real, + self.ifo.meta_data['optimal_SNR']**2, 10) + self.assertAlmostEqual(self.ifo.matched_filter_snr(signal=signal_ifo_expected), + self.ifo.meta_data['matched_filter_SNR'], 10) + self.assertDictEqual(self.parameters, + self.ifo.meta_data['parameters']) + + def test_inject_signal_from_waveform_polarizations_incorrect_length(self): + self.injection_polarizations['plus'] = np.random.random(1000) + self.injection_polarizations['cross'] = np.random.random(1000) + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + with self.assertRaises(ValueError): + self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + + @patch.object(bilby.core.utils.logger, 'warning') + def test_inject_signal_outside_segment_logs_warning(self, m): + self.parameters['geocent_time'] = 24345. + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + self.assertTrue(m.called) + + def test_inject_signal_from_waveform_generator_correct_return_value(self): + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + returned_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters, + waveform_generator=self.waveform_generator) + self.assertTrue(np.array_equal(self.wg_polarizations['plus'], returned_polarizations['plus'])) + self.assertTrue(np.array_equal(self.wg_polarizations['cross'], returned_polarizations['cross'])) + + @patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_generator') + def test_inject_signal_with_waveform_generator_correct_call(self, m): + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + _ = self.ifo.inject_signal(parameters=self.parameters, + waveform_generator=self.waveform_generator) + m.assert_called_with(parameters=self.parameters, + waveform_generator=self.waveform_generator) + + def test_inject_signal_from_waveform_generator_correct_injection(self): + original_strain = self.ifo.strain_data.frequency_domain_strain + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + injection_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters, + waveform_generator=self.waveform_generator) + expected = injection_polarizations['plus'] + injection_polarizations['cross'] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) + + def test_inject_signal_with_injection_polarizations(self): + original_strain = self.ifo.strain_data.frequency_domain_strain + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + self.ifo.inject_signal(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain + self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain)) + + @patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_polarizations') + def test_inject_signal_with_injection_polarizations_and_waveform_generator(self, m): + self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross'] + _ = self.ifo.inject_signal(parameters=self.parameters, + waveform_generator=self.waveform_generator, + injection_polarizations=self.injection_polarizations) + m.assert_called_with(parameters=self.parameters, + injection_polarizations=self.injection_polarizations) + with self.assertRaises(ValueError): + m.assert_called_with(parameters=self.parameters, + injection_polarizations=self.wg_polarizations) + + def test_inject_signal_raises_value_error(self): with self.assertRaises(ValueError): self.ifo.inject_signal(injection_polarizations=None, parameters=None) @@ -517,8 +612,8 @@ class TestInterferometerEquals(unittest.TestCase): self.assertNotEqual(self.ifo_1, self.ifo_2) def test_eq_false_different_ifo_strain_data(self): - self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, - duration=self.duration*2) + self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2, + duration=self.duration * 2) self.ifo_1.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array, frequency_domain_strain=self.strain) self.assertNotEqual(self.ifo_1, self.ifo_2) @@ -766,27 +861,27 @@ class TestInterferometerStrainDataEquals(unittest.TestCase): self.assertNotEqual(self.ifosd_1, self.ifosd_2) def test_eq_different_frequency_array(self): - new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, - duration=self.duration*2) + new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2, + duration=self.duration * 2) self.ifosd_1.frequency_array = new_frequency_array self.assertNotEqual(self.ifosd_1, self.ifosd_2) def test_eq_different_frequency_domain_strain(self): - new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2, - duration=self.duration*2) + new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2, + duration=self.duration * 2) self.ifosd_1._frequency_domain_strain = new_strain self.assertNotEqual(self.ifosd_1, self.ifosd_2) def test_eq_different_time_array(self): - new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2, - duration=self.duration*2) + new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2, + duration=self.duration * 2) self.ifosd_1.time_array = new_time_array self.assertNotEqual(self.ifosd_1, self.ifosd_2) def test_eq_different_time_domain_strain(self): - new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2, - duration=self.duration*2) - self.ifosd_1._time_domain_strain= new_strain + new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2, + duration=self.duration * 2) + self.ifosd_1._time_domain_strain = new_strain self.assertNotEqual(self.ifosd_1, self.ifosd_2) @@ -1250,9 +1345,9 @@ class TestPowerSpectralDensityEquals(unittest.TestCase): self.frequency_array = np.linspace(1, 100) self.psd_array = np.linspace(1, 100) self.psd_from_array_1 = bilby.gw.detector.PowerSpectralDensity. \ - from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array) + from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array) self.psd_from_array_2 = bilby.gw.detector.PowerSpectralDensity. \ - from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array) + from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array) def tearDown(self): del self.psd_from_file_1