diff --git a/bilby/core/utils.py b/bilby/core/utils.py index 5a8d17844a0ad2f8eefd3c50cf2ee6b442682247..54193d23868b54c4596b495f399e41e87c62ddb1 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -12,6 +12,7 @@ import types import subprocess import multiprocessing from importlib import import_module +from numbers import Number import json import warnings @@ -953,34 +954,44 @@ class UnsortedInterp2d(interp2d): """ from scipy.interpolate.dfitpack import bispeu + x, y = self._sanitize_inputs(x, y) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) bad = out_of_bounds_x | out_of_bounds_y - if isinstance(x, float) and isinstance(y, float): + if isinstance(x, Number) and isinstance(y, Number): if bad: output = self.fill_value ier = 0 else: output, ier = bispeu(*self.tck, x, y) + output = float(output) else: - if isinstance(x, np.ndarray): - output = np.zeros_like(x) - x_ = x[~bad] - else: - x_ = x * np.ones_like(y) - if isinstance(y, np.ndarray): - output = np.zeros_like(y) - y_ = y[~bad] - else: - y_ = y * np.ones_like(x) + output = np.empty_like(x) output[bad] = self.fill_value - output[~bad], ier = bispeu(*self.tck, x_, y_) + output[~bad], ier = bispeu(*self.tck, x[~bad], y[~bad]) if ier == 10: raise ValueError("Invalid input data") elif ier: raise TypeError("An error occurred") return output + @staticmethod + def _sanitize_inputs(x, y): + if isinstance(x, np.ndarray) and x.size == 1: + x = float(x) + if isinstance(y, np.ndarray) and y.size == 1: + y = float(y) + if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): + if x.shape != y.shape: + raise ValueError( + "UnsortedInterp2d received unequally shaped arrays" + ) + elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray): + y = y * np.ones_like(x) + elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray): + x = x * np.ones_like(y) + return x, y + # Instantiate the default argument parser at runtime command_line_args, command_line_parser = set_up_command_line_arguments() diff --git a/test/core/utils_test.py b/test/core/utils_test.py index 8eee5f9dd16c8ae1966f9f9d9d2b93e1f3e4d480..774311bda5559f715d5a73f1d181f19869db7325 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -312,5 +312,50 @@ class TestLatexPlotFormat(unittest.TestCase): self.assertTrue(os.path.isfile(self.filename)) +class TestUnsortedInterp2d(unittest.TestCase): + + def setUp(self): + self.xx = np.linspace(0, 1, 10) + self.yy = np.linspace(0, 1, 10) + self.zz = np.random.random((10, 10)) + self.interpolant = bilby.core.utils.UnsortedInterp2d( + self.xx, self.yy, self.zz + ) + + def tearDown(self): + pass + + def test_returns_float_for_floats(self): + self.assertIsInstance(self.interpolant(0.5, 0.5), float) + + def test_returns_none_for_floats_outside_range(self): + self.assertIsNone(self.interpolant(0.5, -0.5)) + self.assertIsNone(self.interpolant(-0.5, 0.5)) + + def test_returns_float_for_float_and_array(self): + self.assertIsInstance( + self.interpolant(0.5, np.random.random(10)), np.ndarray + ) + self.assertIsInstance( + self.interpolant(np.random.random(10), 0.5), np.ndarray + ) + self.assertIsInstance( + self.interpolant(np.random.random(10), np.random.random(10)), + np.ndarray + ) + + def test_raises_for_mismatched_arrays(self): + with self.assertRaises(ValueError): + self.interpolant( + np.random.random(10), np.random.random(20) + ) + + def test_returns_fill_in_correct_place(self): + x_data = np.random.random(10) + y_data = np.random.random(10) + x_data[3] = -1 + self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) + + if __name__ == "__main__": unittest.main()