From d4c02aa092116b6402fb378f3cbcb294bb0e283d Mon Sep 17 00:00:00 2001
From: Colm Talbot <colm.talbot@ligo.org>
Date: Thu, 11 Jul 2024 15:49:00 +0000
Subject: [PATCH] CI fixes

---
 bilby/core/sampler/dynesty.py              |  6 +-
 bilby/core/utils/calculus.py               | 89 ++++++----------------
 bilby/gw/likelihood/base.py                |  8 +-
 bilby/gw/likelihood/roq.py                 |  6 +-
 examples/core_examples/logo/sample_logo.py |  2 +-
 requirements.txt                           |  5 +-
 test/core/sampler/dynesty_test.py          |  2 +-
 test/core/utils_test.py                    |  2 +-
 8 files changed, 38 insertions(+), 82 deletions(-)

diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py
index 70480c6b8..f1617f1ef 100644
--- a/bilby/core/sampler/dynesty.py
+++ b/bilby/core/sampler/dynesty.py
@@ -691,6 +691,8 @@ class Dynesty(NestedSampler):
         )
         while True:
             self.finalize_sampler_kwargs(sampler_kwargs)
+            if getattr(self.sampler, "added_live", False):
+                self.sampler._remove_live_points()
             self.sampler.run_nested(**sampler_kwargs)
             if self.sampler.ncall == old_ncall:
                 break
@@ -705,8 +707,8 @@ class Dynesty(NestedSampler):
             if last_checkpoint_s > self.check_point_delta_t:
                 self.write_current_state()
                 self.plot_current_state()
-            if getattr(self.sampler, "added_live", False):
-                self.sampler._remove_live_points()
+        if getattr(self.sampler, "added_live", False):
+            self.sampler._remove_live_points()
 
         self.sampler.run_nested(**sampler_kwargs)
         self.write_current_state()
diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py
index 618061f48..ac6fcefcd 100644
--- a/bilby/core/utils/calculus.py
+++ b/bilby/core/utils/calculus.py
@@ -1,7 +1,7 @@
 import math
-from numbers import Number
+
 import numpy as np
-from scipy.interpolate import interp2d
+from scipy.interpolate import RectBivariateSpline
 from scipy.special import logsumexp
 
 from .log import logger
@@ -189,79 +189,34 @@ def logtrapzexp(lnf, dx):
     return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)])
 
 
-class UnsortedInterp2d(interp2d):
-    def __call__(self, x, y, dx=0, dy=0, assume_sorted=False):
-        """Modified version of the interp2d call method.
-
-        This avoids the outer product that is done when two numpy
-        arrays are passed.
-
-        Parameters
-        ==========
-        x: See superclass
-        y: See superclass
-        dx: See superclass
-        dy: See superclass
-        assume_sorted: bool, optional
-            This is just a place holder to prevent a warning.
-            Overwriting this will not do anything
+class BoundedRectBivariateSpline(RectBivariateSpline):
 
-        Returns
-        =======
-        array_like: See superclass
+    def __init__(self, x, y, z, bbox=[None] * 4, kx=3, ky=3, s=0, fill_value=None):
+        self.x_min, self.x_max, self.y_min, self.y_max = bbox
+        if self.x_min is None:
+            self.x_min = min(x)
+        if self.x_max is None:
+            self.x_max = max(x)
+        if self.y_min is None:
+            self.y_min = min(y)
+        if self.y_max is None:
+            self.y_max = max(y)
+        self.fill_value = fill_value
+        super().__init__(x=x, y=y, z=z, bbox=bbox, kx=kx, ky=ky, s=s)
 
-        """
-        from scipy.interpolate.dfitpack import bispeu
-
-        x, y = self._sanitize_inputs(x, y)
+    def __call__(self, x, y, dx=0, dy=0, grid=False):
+        result = super().__call__(x=x, y=y, dx=dx, dy=dy, grid=grid)
         out_of_bounds_x = (x < self.x_min) | (x > self.x_max)
         out_of_bounds_y = (y < self.y_min) | (y > self.y_max)
         bad = out_of_bounds_x | out_of_bounds_y
-        if isinstance(x, Number) and isinstance(y, Number):
+        result[bad] = self.fill_value
+        if result.size == 1:
             if bad:
-                output = self.fill_value
-                ier = 0
+                return self.fill_value
             else:
-                output, ier = bispeu(*self.tck, x, y)
-                output = float(output)
+                return result.item()
         else:
-            output = np.empty_like(x)
-            output[bad] = self.fill_value
-            if np.any(~bad):
-                output[~bad], ier = bispeu(*self.tck, x[~bad], y[~bad])
-            else:
-                ier = 0
-        if ier == 10:
-            raise ValueError("Invalid input data")
-        elif ier:
-            raise TypeError("An error occurred")
-        return output
-
-    @staticmethod
-    def _sanitize_inputs(x, y):
-        if isinstance(x, np.ndarray) and x.size == 1:
-            x = float(x)
-        if isinstance(y, np.ndarray) and y.size == 1:
-            y = float(y)
-        if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
-            original_shapes = (x.shape, y.shape)
-            if x.shape != y.shape:
-                while x.ndim > y.ndim:
-                    y = np.expand_dims(y, -1)
-                while y.ndim > x.ndim:
-                    x = np.expand_dims(x, -1)
-            try:
-                x = x * np.ones(y.shape)
-                y = y * np.ones(x.shape)
-            except ValueError:
-                raise ValueError(
-                    f"UnsortedInterp2d received incompatibly shaped arrays: {original_shapes}"
-                )
-        elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray):
-            y = y * np.ones_like(x)
-        elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
-            x = x * np.ones_like(y)
-        return x, y
+            return result
 
 
 def round_up_to_power_of_two(x):
diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py
index e0a09c1e9..d04b28d82 100644
--- a/bilby/gw/likelihood/base.py
+++ b/bilby/gw/likelihood/base.py
@@ -7,7 +7,7 @@ import numpy as np
 from scipy.special import logsumexp
 
 from ...core.likelihood import Likelihood
-from ...core.utils import logger, UnsortedInterp2d, create_time_series
+from ...core.utils import logger, BoundedRectBivariateSpline, create_time_series
 from ...core.prior import Interped, Prior, Uniform, DeltaFunction
 from ..detector import InterferometerList, get_empty_interferometer, calibration
 from ..prior import BBHPriorDict, Cosmological
@@ -752,7 +752,7 @@ class GravitationalWaveTransient(Likelihood):
             d_inner_h_ref = np.real(d_inner_h_ref)
 
         return self._interp_dist_margd_loglikelihood(
-            d_inner_h_ref, h_inner_h_ref)
+            d_inner_h_ref, h_inner_h_ref, grid=False)
 
     def phase_marginalized_likelihood(self, d_inner_h, h_inner_h):
         d_inner_h = ln_i0(abs(d_inner_h))
@@ -891,9 +891,9 @@ class GravitationalWaveTransient(Likelihood):
                 self._create_lookup_table()
         else:
             self._create_lookup_table()
-        self._interp_dist_margd_loglikelihood = UnsortedInterp2d(
+        self._interp_dist_margd_loglikelihood = BoundedRectBivariateSpline(
             self._d_inner_h_ref_array, self._optimal_snr_squared_ref_array,
-            self._dist_margd_loglikelihood_array, kind='cubic', fill_value=-np.inf)
+            self._dist_margd_loglikelihood_array.T, fill_value=-np.inf)
 
     @property
     def cached_lookup_table_filename(self):
diff --git a/bilby/gw/likelihood/roq.py b/bilby/gw/likelihood/roq.py
index 12d5212f2..b0bac463d 100644
--- a/bilby/gw/likelihood/roq.py
+++ b/bilby/gw/likelihood/roq.py
@@ -1110,11 +1110,11 @@ class ROQGravitationalWaveTransient(GravitationalWaveTransient):
             f_high: float
                 The maximum frequency which must be considered
             """
-            from scipy.integrate import simps
+            from scipy.integrate import simpson
             integrand1 = np.power(freq, -7. / 3) / psd
-            integral1 = simps(integrand1, freq)
+            integral1 = simpson(y=integrand1, x=freq)
             integrand3 = np.power(freq, 2. / 3.) / (psd * integral1)
-            f_3_bar = simps(integrand3, freq)
+            f_3_bar = simpson(y=integrand3, x=freq)
 
             f_high = scaling * f_3_bar**(1 / 3)
 
diff --git a/examples/core_examples/logo/sample_logo.py b/examples/core_examples/logo/sample_logo.py
index 6b5b7225a..a9993d8ec 100644
--- a/examples/core_examples/logo/sample_logo.py
+++ b/examples/core_examples/logo/sample_logo.py
@@ -18,7 +18,7 @@ for letter in ["B", "I", "L", "Y"]:
     img = 1 - io.imread("{}.png".format(letter), as_gray=True)[::-1, :]
     x = np.arange(img.shape[0])
     y = np.arange(img.shape[1])
-    interp = si.interpolate.interp2d(x, y, img.T)
+    interp = si.RectBivariateSpline(x, y, img, kx=1, ky=1)
 
     likelihood = Likelihood(interp)
 
diff --git a/requirements.txt b/requirements.txt
index 05c8c4fad..37bdb21f5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,11 +1,10 @@
 bilby.cython>=0.3.0
-# remove pin after https://git.ligo.org/lscsoft/bilby/-/merge_requests/1368
-dynesty>=2.0.1,<2.1.4
+dynesty>=2.0.1
 emcee
 corner
 numpy
 matplotlib
-# remove pin after https://git.ligo.org/lscsoft/bilby/-/merge_requests/1368
+# see https://github.com/healpy/healpy/pull/953
 scipy>=1.5,<1.14
 pandas
 dill
diff --git a/test/core/sampler/dynesty_test.py b/test/core/sampler/dynesty_test.py
index d33cc2e23..5d1c534a0 100644
--- a/test/core/sampler/dynesty_test.py
+++ b/test/core/sampler/dynesty_test.py
@@ -216,7 +216,7 @@ class TestCustomSampler(unittest.TestCase):
         self.sampler = cls(
             loglikelihood=lambda x: 1,
             prior_transform=lambda x: x,
-            npdim=4,
+            ndim=4,
             live_points=(np.zeros((1000, 4)), np.zeros((1000, 4)), np.zeros(1000)),
             update_interval=None,
             first_update=dict(),
diff --git a/test/core/utils_test.py b/test/core/utils_test.py
index d6b3e8902..ed63916a3 100644
--- a/test/core/utils_test.py
+++ b/test/core/utils_test.py
@@ -320,7 +320,7 @@ class TestUnsortedInterp2d(unittest.TestCase):
         self.xx = np.linspace(0, 1, 10)
         self.yy = np.linspace(0, 1, 10)
         self.zz = np.random.random((10, 10))
-        self.interpolant = bilby.core.utils.UnsortedInterp2d(self.xx, self.yy, self.zz)
+        self.interpolant = bilby.core.utils.BoundedRectBivariateSpline(self.xx, self.yy, self.zz)
 
     def tearDown(self):
         pass
-- 
GitLab