Skip to content
Snippets Groups Projects
Commit a0d16709 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'modify_inject_signal' into 'master'

Modify inject signal

See merge request lscsoft/bilby!508
parents d525afba 3660625a
No related branches found
No related tags found
1 merge request!508Modify inject signal
Pipeline #65366 passed
......@@ -287,20 +287,23 @@ class Interferometer(object):
return signal_ifo
def inject_signal(self, parameters=None, injection_polarizations=None,
def inject_signal(self, parameters, injection_polarizations=None,
waveform_generator=None):
""" Inject a signal into noise
""" General signal injection method.
Provide the injection parameters and either the injection polarizations
or the waveform generator to inject a signal into the detector.
Defaults to the injection polarizations is both are given.
Parameters
----------
parameters: dict
Parameters of the injection.
injection_polarizations: dict
injection_polarizations: dict, optional
Polarizations of waveform to inject, output of
`waveform_generator.frequency_domain_strain()`. If
`waveform_generator` is also given, the injection_polarizations will
be calculated directly and this argument can be ignored.
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator, optional
A WaveformGenerator instance using the source model to inject. If
`injection_polarizations` is given, this will be ignored.
......@@ -313,40 +316,71 @@ class Interferometer(object):
Returns
-------
injection_polarizations: dict
The injected polarizations. This is the same as the injection_polarizations parameters
if it was passed in. Otherwise it is the return value of waveform_generator.frequency_domain_strain().
"""
if injection_polarizations is None and waveform_generator is None:
raise ValueError(
"inject_signal needs one of waveform_generator or "
"injection_polarizations.")
elif injection_polarizations is not None:
self.inject_signal_from_waveform_polarizations(parameters=parameters,
injection_polarizations=injection_polarizations)
elif waveform_generator is not None:
injection_polarizations = self.inject_signal_from_waveform_generator(parameters=parameters,
waveform_generator=waveform_generator)
return injection_polarizations
def inject_signal_from_waveform_generator(self, parameters, waveform_generator):
""" Inject a signal using a waveform generator and a set of parameters.
Alternative to `inject_signal` and `inject_signal_from_waveform_polarizations`
Parameters
----------
parameters: dict
Parameters of the injection.
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
A WaveformGenerator instance using the source model to inject.
Note
-------
if your signal takes a substantial amount of time to generate, or
you experience buggy behaviour. It is preferable to use the
inject_signal_from_waveform_polarizations() method.
Returns
-------
injection_polarizations: dict
The internally generated injection parameters
"""
injection_polarizations = \
waveform_generator.frequency_domain_strain(parameters)
self.inject_signal_from_waveform_polarizations(parameters=parameters,
injection_polarizations=injection_polarizations)
return injection_polarizations
def inject_signal_from_waveform_polarizations(self, parameters, injection_polarizations):
""" Inject a signal into the detector from a dict of waveform polarizations.
Alternative to `inject_signal` and `inject_signal_from_waveform_generator`.
if injection_polarizations is None:
if waveform_generator is not None:
injection_polarizations = \
waveform_generator.frequency_domain_strain(parameters)
else:
raise ValueError(
"inject_signal needs one of waveform_generator or "
"injection_polarizations.")
if injection_polarizations is None:
raise ValueError(
'Trying to inject signal which is None. The most likely cause'
' is that waveform_generator.frequency_domain_strain returned'
' None. This can be caused if, e.g., mass_2 > mass_1.')
Parameters
----------
parameters: dict
Parameters of the injection.
injection_polarizations: dict
Polarizations of waveform to inject, output of
`waveform_generator.frequency_domain_strain()`.
"""
if not self.strain_data.time_within_data(parameters['geocent_time']):
logger.warning(
'Injecting signal outside segment, start_time={}, merger time={}.'
.format(self.strain_data.start_time, parameters['geocent_time']))
signal_ifo = self.get_detector_response(injection_polarizations, parameters)
if np.shape(self.strain_data.frequency_domain_strain).__eq__(np.shape(signal_ifo)):
self.strain_data.frequency_domain_strain = \
signal_ifo + self.strain_data.frequency_domain_strain
else:
logger.info('Injecting into zero noise.')
self.set_strain_data_from_frequency_domain_strain(
signal_ifo,
sampling_frequency=self.strain_data.sampling_frequency,
duration=self.strain_data.duration,
start_time=self.strain_data.start_time)
self.strain_data.frequency_domain_strain += signal_ifo
self.meta_data['optimal_SNR'] = (
np.sqrt(self.optimal_snr_squared(signal=signal_ifo)).real)
......@@ -360,8 +394,6 @@ class Interferometer(object):
for key in parameters:
logger.info(' {} = {}'.format(key, parameters[key]))
return injection_polarizations
@property
def amplitude_spectral_density_array(self):
""" Returns the amplitude spectral density (ASD) given we know a power spectral denstiy (PSD)
......
from __future__ import absolute_import
import bilby
import inspect
import unittest
import mock
from mock import MagicMock
......@@ -261,6 +262,17 @@ class TestInterferometer(unittest.TestCase):
self.ifo.strain_data.set_from_frequency_domain_strain(
np.linspace(0, 4096, 4097), sampling_frequency=4096, duration=2)
self.outdir = 'outdir'
self.injection_polarizations = dict()
np.random.seed(42)
self.injection_polarizations['plus'] = np.random.random(4097)
self.injection_polarizations['cross'] = np.random.random(4097)
self.waveform_generator = MagicMock()
self.wg_polarizations = dict(plus=np.random.random(4097), cross=np.random.random(4097))
self.waveform_generator.frequency_domain_strain = lambda _: self.wg_polarizations
self.parameters = dict(ra=0., dec=0., geocent_time=0., psi=0.)
bilby.core.utils.check_directory_exists_and_if_not_mkdir(self.outdir)
def tearDown(self):
......@@ -277,6 +289,10 @@ class TestInterferometer(unittest.TestCase):
del self.xarm_tilt
del self.yarm_tilt
del self.ifo
del self.injection_polarizations
del self.wg_polarizations
del self.waveform_generator
del self.parameters
rmtree(self.outdir)
def test_name_setting(self):
......@@ -342,7 +358,86 @@ class TestInterferometer(unittest.TestCase):
parameters=dict(ra=0, dec=0, geocent_time=0, psi=0))
self.assertTrue(np.array_equal(response, (plus + cross) * self.ifo.frequency_mask * np.exp(-0j)))
def test_inject_signal_no_waveform_polarizations(self):
def test_inject_signal_from_waveform_polarizations_correct_injection(self):
original_strain = self.ifo.strain_data.frequency_domain_strain
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain
self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
def test_inject_signal_from_waveform_polarizations_meta_data(self):
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
signal_ifo_expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross']
self.assertAlmostEqual(self.ifo.optimal_snr_squared(signal=signal_ifo_expected).real,
self.ifo.meta_data['optimal_SNR']**2, 10)
self.assertAlmostEqual(self.ifo.matched_filter_snr(signal=signal_ifo_expected),
self.ifo.meta_data['matched_filter_SNR'], 10)
self.assertDictEqual(self.parameters,
self.ifo.meta_data['parameters'])
def test_inject_signal_from_waveform_polarizations_incorrect_length(self):
self.injection_polarizations['plus'] = np.random.random(1000)
self.injection_polarizations['cross'] = np.random.random(1000)
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
with self.assertRaises(ValueError):
self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
@patch.object(bilby.core.utils.logger, 'warning')
def test_inject_signal_outside_segment_logs_warning(self, m):
self.parameters['geocent_time'] = 24345.
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
self.ifo.inject_signal_from_waveform_polarizations(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
self.assertTrue(m.called)
def test_inject_signal_from_waveform_generator_correct_return_value(self):
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
returned_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters,
waveform_generator=self.waveform_generator)
self.assertTrue(np.array_equal(self.wg_polarizations['plus'], returned_polarizations['plus']))
self.assertTrue(np.array_equal(self.wg_polarizations['cross'], returned_polarizations['cross']))
@patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_generator')
def test_inject_signal_with_waveform_generator_correct_call(self, m):
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
_ = self.ifo.inject_signal(parameters=self.parameters,
waveform_generator=self.waveform_generator)
m.assert_called_with(parameters=self.parameters,
waveform_generator=self.waveform_generator)
def test_inject_signal_from_waveform_generator_correct_injection(self):
original_strain = self.ifo.strain_data.frequency_domain_strain
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
injection_polarizations = self.ifo.inject_signal_from_waveform_generator(parameters=self.parameters,
waveform_generator=self.waveform_generator)
expected = injection_polarizations['plus'] + injection_polarizations['cross'] + original_strain
self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
def test_inject_signal_with_injection_polarizations(self):
original_strain = self.ifo.strain_data.frequency_domain_strain
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
self.ifo.inject_signal(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
expected = self.injection_polarizations['plus'] + self.injection_polarizations['cross'] + original_strain
self.assertTrue(np.array_equal(expected, self.ifo.strain_data._frequency_domain_strain))
@patch.object(bilby.gw.detector.Interferometer, 'inject_signal_from_waveform_polarizations')
def test_inject_signal_with_injection_polarizations_and_waveform_generator(self, m):
self.ifo.get_detector_response = lambda x, params: x['plus'] + x['cross']
_ = self.ifo.inject_signal(parameters=self.parameters,
waveform_generator=self.waveform_generator,
injection_polarizations=self.injection_polarizations)
m.assert_called_with(parameters=self.parameters,
injection_polarizations=self.injection_polarizations)
with self.assertRaises(ValueError):
m.assert_called_with(parameters=self.parameters,
injection_polarizations=self.wg_polarizations)
def test_inject_signal_raises_value_error(self):
with self.assertRaises(ValueError):
self.ifo.inject_signal(injection_polarizations=None, parameters=None)
......@@ -517,8 +612,8 @@ class TestInterferometerEquals(unittest.TestCase):
self.assertNotEqual(self.ifo_1, self.ifo_2)
def test_eq_false_different_ifo_strain_data(self):
self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
duration=self.duration*2)
self.strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
duration=self.duration * 2)
self.ifo_1.set_strain_data_from_frequency_domain_strain(frequency_array=self.frequency_array,
frequency_domain_strain=self.strain)
self.assertNotEqual(self.ifo_1, self.ifo_2)
......@@ -766,27 +861,27 @@ class TestInterferometerStrainDataEquals(unittest.TestCase):
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
def test_eq_different_frequency_array(self):
new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
duration=self.duration*2)
new_frequency_array = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
duration=self.duration * 2)
self.ifosd_1.frequency_array = new_frequency_array
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
def test_eq_different_frequency_domain_strain(self):
new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency/2,
duration=self.duration*2)
new_strain = bilby.utils.create_frequency_series(sampling_frequency=self.sampling_frequency / 2,
duration=self.duration * 2)
self.ifosd_1._frequency_domain_strain = new_strain
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
def test_eq_different_time_array(self):
new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2,
duration=self.duration*2)
new_time_array = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2,
duration=self.duration * 2)
self.ifosd_1.time_array = new_time_array
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
def test_eq_different_time_domain_strain(self):
new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency/2,
duration=self.duration*2)
self.ifosd_1._time_domain_strain= new_strain
new_strain = bilby.utils.create_time_series(sampling_frequency=self.sampling_frequency / 2,
duration=self.duration * 2)
self.ifosd_1._time_domain_strain = new_strain
self.assertNotEqual(self.ifosd_1, self.ifosd_2)
......@@ -1250,9 +1345,9 @@ class TestPowerSpectralDensityEquals(unittest.TestCase):
self.frequency_array = np.linspace(1, 100)
self.psd_array = np.linspace(1, 100)
self.psd_from_array_1 = bilby.gw.detector.PowerSpectralDensity. \
from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array)
from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array)
self.psd_from_array_2 = bilby.gw.detector.PowerSpectralDensity. \
from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array= self.psd_array)
from_power_spectral_density_array(frequency_array=self.frequency_array, psd_array=self.psd_array)
def tearDown(self):
del self.psd_from_file_1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment