From 9dd54b8ebd08db8f6f69f747ad24e1226c21a230 Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Tue, 11 Jan 2022 13:11:07 +0000
Subject: [PATCH] add raise_error to network

---
 bilby/gw/detector/networks.py     | 18 ++++++++++++++++--
 test/gw/detector/networks_test.py |  4 ++--
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/bilby/gw/detector/networks.py b/bilby/gw/detector/networks.py
index d354f2bf4..4370219c1 100644
--- a/bilby/gw/detector/networks.py
+++ b/bilby/gw/detector/networks.py
@@ -109,7 +109,13 @@ class InterferometerList(list):
                                                            duration=duration,
                                                            start_time=start_time)
 
-    def inject_signal(self, parameters=None, injection_polarizations=None, waveform_generator=None):
+    def inject_signal(
+        self,
+        parameters=None,
+        injection_polarizations=None,
+        waveform_generator=None,
+        raise_error=True,
+    ):
         """ Inject a signal into noise in each of the three detectors.
 
         Parameters
@@ -124,6 +130,9 @@ class InterferometerList(list):
         waveform_generator: bilby.gw.waveform_generator.WaveformGenerator
             A WaveformGenerator instance using the source model to inject. If
             `injection_polarizations` is given, this will be ignored.
+        raise_error: bool
+            Whether to raise an error if the injected signal does not fit in
+            the segment.
 
         Notes
         ==========
@@ -148,7 +157,12 @@ class InterferometerList(list):
         all_injection_polarizations = list()
         for interferometer in self:
             all_injection_polarizations.append(
-                interferometer.inject_signal(parameters=parameters, injection_polarizations=injection_polarizations))
+                interferometer.inject_signal(
+                    parameters=parameters,
+                    injection_polarizations=injection_polarizations,
+                    raise_error=raise_error,
+                )
+            )
 
         return all_injection_polarizations
 
diff --git a/test/gw/detector/networks_test.py b/test/gw/detector/networks_test.py
index 4484cfb12..2ad4060e5 100644
--- a/test/gw/detector/networks_test.py
+++ b/test/gw/detector/networks_test.py
@@ -240,8 +240,8 @@ class TestInterferometerList(unittest.TestCase):
 
     @patch.object(bilby.gw.detector.Interferometer, "inject_signal")
     def test_inject_signal_with_inj_pol(self, m):
-        self.ifo_list.inject_signal(injection_polarizations=dict(plus=1))
-        m.assert_called_with(parameters=None, injection_polarizations=dict(plus=1))
+        self.ifo_list.inject_signal(injection_polarizations=dict(plus=1), raise_error=False)
+        m.assert_called_with(parameters=None, injection_polarizations=dict(plus=1), raise_error=False)
         self.assertEqual(len(self.ifo_list), m.call_count)
 
     @patch.object(bilby.gw.detector.Interferometer, "inject_signal")
-- 
GitLab