Skip to content
Snippets Groups Projects
Commit 9dd54b8e authored by Colm Talbot's avatar Colm Talbot Committed by Gregory Ashton
Browse files

add raise_error to network

parent 6cf9a859
No related branches found
No related tags found
1 merge request!1041add raise_error to network
...@@ -109,7 +109,13 @@ class InterferometerList(list): ...@@ -109,7 +109,13 @@ class InterferometerList(list):
duration=duration, duration=duration,
start_time=start_time) start_time=start_time)
def inject_signal(self, parameters=None, injection_polarizations=None, waveform_generator=None): def inject_signal(
self,
parameters=None,
injection_polarizations=None,
waveform_generator=None,
raise_error=True,
):
""" Inject a signal into noise in each of the three detectors. """ Inject a signal into noise in each of the three detectors.
Parameters Parameters
...@@ -124,6 +130,9 @@ class InterferometerList(list): ...@@ -124,6 +130,9 @@ class InterferometerList(list):
waveform_generator: bilby.gw.waveform_generator.WaveformGenerator waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
A WaveformGenerator instance using the source model to inject. If A WaveformGenerator instance using the source model to inject. If
`injection_polarizations` is given, this will be ignored. `injection_polarizations` is given, this will be ignored.
raise_error: bool
Whether to raise an error if the injected signal does not fit in
the segment.
Notes Notes
========== ==========
...@@ -148,7 +157,12 @@ class InterferometerList(list): ...@@ -148,7 +157,12 @@ class InterferometerList(list):
all_injection_polarizations = list() all_injection_polarizations = list()
for interferometer in self: for interferometer in self:
all_injection_polarizations.append( all_injection_polarizations.append(
interferometer.inject_signal(parameters=parameters, injection_polarizations=injection_polarizations)) interferometer.inject_signal(
parameters=parameters,
injection_polarizations=injection_polarizations,
raise_error=raise_error,
)
)
return all_injection_polarizations return all_injection_polarizations
......
...@@ -240,8 +240,8 @@ class TestInterferometerList(unittest.TestCase): ...@@ -240,8 +240,8 @@ class TestInterferometerList(unittest.TestCase):
@patch.object(bilby.gw.detector.Interferometer, "inject_signal") @patch.object(bilby.gw.detector.Interferometer, "inject_signal")
def test_inject_signal_with_inj_pol(self, m): def test_inject_signal_with_inj_pol(self, m):
self.ifo_list.inject_signal(injection_polarizations=dict(plus=1)) self.ifo_list.inject_signal(injection_polarizations=dict(plus=1), raise_error=False)
m.assert_called_with(parameters=None, injection_polarizations=dict(plus=1)) m.assert_called_with(parameters=None, injection_polarizations=dict(plus=1), raise_error=False)
self.assertEqual(len(self.ifo_list), m.call_count) self.assertEqual(len(self.ifo_list), m.call_count)
@patch.object(bilby.gw.detector.Interferometer, "inject_signal") @patch.object(bilby.gw.detector.Interferometer, "inject_signal")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment