Skip to content
Snippets Groups Projects
Commit 9dc1b8fb authored by Colm Talbot's avatar Colm Talbot Committed by Moritz Huebner
Browse files

Fix custom interpolant

parent 81fca8e3
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ import types
import subprocess
import multiprocessing
from importlib import import_module
from numbers import Number
import json
import warnings
......@@ -953,34 +954,44 @@ class UnsortedInterp2d(interp2d):
"""
from scipy.interpolate.dfitpack import bispeu
x, y = self._sanitize_inputs(x, y)
out_of_bounds_x = (x < self.x_min) | (x > self.x_max)
out_of_bounds_y = (y < self.y_min) | (y > self.y_max)
bad = out_of_bounds_x | out_of_bounds_y
if isinstance(x, float) and isinstance(y, float):
if isinstance(x, Number) and isinstance(y, Number):
if bad:
output = self.fill_value
ier = 0
else:
output, ier = bispeu(*self.tck, x, y)
output = float(output)
else:
if isinstance(x, np.ndarray):
output = np.zeros_like(x)
x_ = x[~bad]
else:
x_ = x * np.ones_like(y)
if isinstance(y, np.ndarray):
output = np.zeros_like(y)
y_ = y[~bad]
else:
y_ = y * np.ones_like(x)
output = np.empty_like(x)
output[bad] = self.fill_value
output[~bad], ier = bispeu(*self.tck, x_, y_)
output[~bad], ier = bispeu(*self.tck, x[~bad], y[~bad])
if ier == 10:
raise ValueError("Invalid input data")
elif ier:
raise TypeError("An error occurred")
return output
@staticmethod
def _sanitize_inputs(x, y):
if isinstance(x, np.ndarray) and x.size == 1:
x = float(x)
if isinstance(y, np.ndarray) and y.size == 1:
y = float(y)
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
if x.shape != y.shape:
raise ValueError(
"UnsortedInterp2d received unequally shaped arrays"
)
elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray):
y = y * np.ones_like(x)
elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
x = x * np.ones_like(y)
return x, y
# Instantiate the default argument parser at runtime
command_line_args, command_line_parser = set_up_command_line_arguments()
......
......@@ -312,5 +312,50 @@ class TestLatexPlotFormat(unittest.TestCase):
self.assertTrue(os.path.isfile(self.filename))
class TestUnsortedInterp2d(unittest.TestCase):
def setUp(self):
self.xx = np.linspace(0, 1, 10)
self.yy = np.linspace(0, 1, 10)
self.zz = np.random.random((10, 10))
self.interpolant = bilby.core.utils.UnsortedInterp2d(
self.xx, self.yy, self.zz
)
def tearDown(self):
pass
def test_returns_float_for_floats(self):
self.assertIsInstance(self.interpolant(0.5, 0.5), float)
def test_returns_none_for_floats_outside_range(self):
self.assertIsNone(self.interpolant(0.5, -0.5))
self.assertIsNone(self.interpolant(-0.5, 0.5))
def test_returns_float_for_float_and_array(self):
self.assertIsInstance(
self.interpolant(0.5, np.random.random(10)), np.ndarray
)
self.assertIsInstance(
self.interpolant(np.random.random(10), 0.5), np.ndarray
)
self.assertIsInstance(
self.interpolant(np.random.random(10), np.random.random(10)),
np.ndarray
)
def test_raises_for_mismatched_arrays(self):
with self.assertRaises(ValueError):
self.interpolant(
np.random.random(10), np.random.random(20)
)
def test_returns_fill_in_correct_place(self):
x_data = np.random.random(10)
y_data = np.random.random(10)
x_data[3] = -1
self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3]))
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment