From 2c88b04cf2f2665bca9d4397ec0ad63692a56976 Mon Sep 17 00:00:00 2001
From: Kevin Kuns <kevin.kuns@ligo.org>
Date: Tue, 7 Dec 2021 17:37:35 -0500
Subject: [PATCH] fix bug when updating ifo struct in budget run

---
 gwinc/nb.py                  |  2 +-
 test/budgets/test_budgets.py | 35 +++++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/gwinc/nb.py b/gwinc/nb.py
index 5513888..c0d4958 100644
--- a/gwinc/nb.py
+++ b/gwinc/nb.py
@@ -271,7 +271,7 @@ class Noise(BudgetItem):
                 ifo_hash = ifo.hash(ifo._orig_keys)
                 if ifo_hash != getattr(self, '_ifo_hash', 0):
                     logger.debug("ifo hash change")
-                    kwargs['ifo'] = self.ifo
+                    kwargs['ifo'] = ifo
             self._ifo_hash = ifo_hash
 
         if kwargs:
diff --git a/test/budgets/test_budgets.py b/test/budgets/test_budgets.py
index 3c01773..283bbfd 100644
--- a/test/budgets/test_budgets.py
+++ b/test/budgets/test_budgets.py
@@ -1,7 +1,10 @@
 """
 """
+import numpy as np
 import gwinc
 from gwinc import load_budget
+from copy import deepcopy
+import pytest
 
 
 def test_load(pprint, tpath_join, fpath_join):
@@ -12,3 +15,35 @@ def test_load(pprint, tpath_join, fpath_join):
         fig = trace.plot()
         fig.savefig(tpath_join('budget_{}.pdf'.format(ifo)))
 
+
+@pytest.mark.logic
+@pytest.mark.fast
+def test_update_ifo_struct():
+    """
+    Test that the noise is recalculated when the ifo struct is updated
+    """
+    budget = gwinc.load_budget('CE2silica')
+    tr1 = budget.run()
+    budget.ifo.Suspension.VHCoupling.theta *= 2
+    tr2 = budget.run()
+    assert np.all(
+        tr2.Seismic.SeismicVertical.asd == 2*tr1.Seismic.SeismicVertical.asd)
+
+
+@pytest.mark.logic
+@pytest.mark.fast
+def test_change_ifo_struct():
+    """
+    Test that the noise is recalculated when a new ifo struct is passed to run
+    """
+    budget = gwinc.load_budget('CE2silica')
+    ifo1 = deepcopy(budget.ifo)
+    ifo2 = deepcopy(budget.ifo)
+    ifo2.Suspension.VHCoupling.theta *= 2
+    tr1 = budget.run(ifo=ifo1)
+    tr2 = budget.run(ifo=ifo2)
+    tr3 = budget.run(ifo=ifo1)
+    assert np.all(tr1.asd == tr3.asd)
+    assert np.all(
+        tr2.Seismic.SeismicVertical.asd == 2*tr1.Seismic.SeismicVertical.asd)
+
-- 
GitLab