diff --git a/bilby/core/grid.py b/bilby/core/grid.py index 8562e6d67b3afbaa863e85fe8125e62248a7bc41..8e264872d84fa2ef1daa9fc9a8b0d51c9de398a0 100644 --- a/bilby/core/grid.py +++ b/bilby/core/grid.py @@ -191,8 +191,10 @@ class Grid(object): places = self.sample_points[name] if len(places) > 1: + dx = np.diff(places) out = np.apply_along_axis( - logtrapzexp, axis, log_array, places[1] - places[0]) + logtrapzexp, axis, log_array, dx + ) else: # no marginalisation required, just remove the singleton dimension z = log_array.shape diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py index fbd64f9f012e8983a20873f88c50db1953a0973c..ef7af61dec6a139e662cf53438cd07e2aaba6255 100644 --- a/bilby/core/utils/calculus.py +++ b/bilby/core/utils/calculus.py @@ -6,8 +6,16 @@ from scipy.special import logsumexp from .logger import logger -def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3, - epsscale=0.5, nonfixedidx=None): +def derivatives( + vals, + func, + releps=1e-3, + abseps=None, + mineps=1e-9, + reltol=1e-3, + epsscale=0.5, + nonfixedidx=None, +): """ Calculate the partial derivatives of a function at a set of values. The derivatives are calculated using the central difference, using an iterative @@ -54,19 +62,19 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3, grads = np.zeros(len(nonfixedidx)) # maximum number of times the gradient can change sign - flipflopmax = 10. + flipflopmax = 10.0 # set steps if abseps is None: if isinstance(releps, float): eps = np.abs(vals) * releps - eps[eps == 0.] = releps # if any values are zero set eps to releps + eps[eps == 0.0] = releps # if any values are zero set eps to releps teps = releps * np.ones(len(vals)) elif isinstance(releps, (list, np.ndarray)): if len(releps) != len(vals): raise ValueError("Problem with input relative step sizes") eps = np.multiply(np.abs(vals), releps) - eps[eps == 0.] = np.array(releps)[eps == 0.] + eps[eps == 0.0] = np.array(releps)[eps == 0.0] teps = releps else: raise RuntimeError("Relative step sizes are not a recognised type!") @@ -107,8 +115,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3, cureps *= epsscale if cureps < mineps or flipflop > flipflopmax: # if no convergence set flat derivative (TODO: check if there is a better thing to do instead) - logger.warning("Derivative calculation did not converge: setting flat derivative.") - grads[count] = 0. + logger.warning( + "Derivative calculation did not converge: setting flat derivative." + ) + grads[count] = 0.0 break leps *= epsscale @@ -122,10 +132,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3, break # check whether previous diff and current diff are the same within reltol - rat = (cdiff / cdiffnew) - if np.isfinite(rat) and rat > 0.: + rat = cdiff / cdiffnew + if np.isfinite(rat) and rat > 0.0: # gradient has not changed sign - if np.abs(1. - rat) < reltol: + if np.abs(1.0 - rat) < reltol: grads[count] = cdiffnew break else: @@ -143,7 +153,7 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3, def logtrapzexp(lnf, dx): """ - Perform trapezium rule integration for the logarithm of a function on a regular grid. + Perform trapezium rule integration for the logarithm of a function on a grid. Parameters ========== @@ -157,13 +167,30 @@ def logtrapzexp(lnf, dx): ======= The natural logarithm of the area under the function. """ - return np.log(dx / 2.) + logsumexp([logsumexp(lnf[:-1]), logsumexp(lnf[1:])]) + lnfdx1 = lnf[:-1] + lnfdx2 = lnf[1:] + if isinstance(dx, (int, float)): + C = np.log(dx / 2.0) + elif isinstance(dx, (list, np.ndarray)): + if len(dx) != len(lnf) - 1: + raise ValueError( + "Step size array must have length one less than the function length" + ) + + lndx = np.log(dx) + lnfdx1 = lnfdx1.copy() + lndx + lnfdx2 = lnfdx2.copy() + lndx + C = -np.log(2.0) + else: + raise TypeError("Step size must be a single value or array-like") + + return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)]) -class UnsortedInterp2d(interp2d): +class UnsortedInterp2d(interp2d): def __call__(self, x, y, dx=0, dy=0, assume_sorted=False): - """ Modified version of the interp2d call method. + """Modified version of the interp2d call method. This avoids the outer product that is done when two numpy arrays are passed. @@ -184,6 +211,7 @@ class UnsortedInterp2d(interp2d): """ from scipy.interpolate.dfitpack import bispeu + x, y = self._sanitize_inputs(x, y) out_of_bounds_x = (x < self.x_min) | (x > self.x_max) out_of_bounds_y = (y < self.y_min) | (y > self.y_max) @@ -216,9 +244,7 @@ class UnsortedInterp2d(interp2d): y = float(y) if isinstance(x, np.ndarray) and isinstance(y, np.ndarray): if x.shape != y.shape: - raise ValueError( - "UnsortedInterp2d received unequally shaped arrays" - ) + raise ValueError("UnsortedInterp2d received unequally shaped arrays") elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray): y = y * np.ones_like(x) elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray): diff --git a/test/core/utils_test.py b/test/core/utils_test.py index 774311bda5559f715d5a73f1d181f19869db7325..d6b3e8902dac7c0ca77275a37a43e24b8e43d4e8 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -270,6 +270,7 @@ class TestLatexPlotFormat(unittest.TestCase): fig, ax = plt.subplots() ax.plot(self.x, self.y) fig.savefig(self.filename) + plot() self.assertTrue(os.path.isfile(self.filename)) @@ -279,6 +280,7 @@ class TestLatexPlotFormat(unittest.TestCase): fig, ax = plt.subplots() ax.plot(self.x, self.y) fig.savefig(self.filename) + plot(BILBY_MATHDEFAULT=1) self.assertTrue(os.path.isfile(self.filename)) @@ -288,6 +290,7 @@ class TestLatexPlotFormat(unittest.TestCase): fig, ax = plt.subplots() ax.plot(self.x, self.y) fig.savefig(self.filename) + plot(BILBY_MATHDEFAULT=0) self.assertTrue(os.path.isfile(self.filename)) @@ -313,14 +316,11 @@ class TestLatexPlotFormat(unittest.TestCase): class TestUnsortedInterp2d(unittest.TestCase): - def setUp(self): self.xx = np.linspace(0, 1, 10) self.yy = np.linspace(0, 1, 10) self.zz = np.random.random((10, 10)) - self.interpolant = bilby.core.utils.UnsortedInterp2d( - self.xx, self.yy, self.zz - ) + self.interpolant = bilby.core.utils.UnsortedInterp2d(self.xx, self.yy, self.zz) def tearDown(self): pass @@ -333,22 +333,15 @@ class TestUnsortedInterp2d(unittest.TestCase): self.assertIsNone(self.interpolant(-0.5, 0.5)) def test_returns_float_for_float_and_array(self): + self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray) + self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray) self.assertIsInstance( - self.interpolant(0.5, np.random.random(10)), np.ndarray - ) - self.assertIsInstance( - self.interpolant(np.random.random(10), 0.5), np.ndarray - ) - self.assertIsInstance( - self.interpolant(np.random.random(10), np.random.random(10)), - np.ndarray + self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray ) def test_raises_for_mismatched_arrays(self): with self.assertRaises(ValueError): - self.interpolant( - np.random.random(10), np.random.random(20) - ) + self.interpolant(np.random.random(10), np.random.random(20)) def test_returns_fill_in_correct_place(self): x_data = np.random.random(10) @@ -357,5 +350,67 @@ class TestUnsortedInterp2d(unittest.TestCase): self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) +class TestTrapeziumRuleIntegration(unittest.TestCase): + def setUp(self): + self.x = np.linspace(0, 1, 100) + self.dxs = np.diff(self.x) + self.dx = self.dxs[0] + with np.errstate(divide="ignore"): + self.lnfunc1 = np.log(self.x) + self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 + with np.errstate(divide="ignore"): + self.lnfunc2 = np.log(self.x ** 2) + self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 + + self.irregularx = np.array( + [ + self.x[0], + self.x[12], + self.x[19], + self.x[33], + self.x[49], + self.x[55], + self.x[59], + self.x[61], + self.x[73], + self.x[89], + self.x[93], + self.x[97], + self.x[-1], + ] + ) + with np.errstate(divide="ignore"): + self.lnfunc1irregular = np.log(self.irregularx) + self.lnfunc2irregular = np.log(self.irregularx ** 2) + self.irregulardxs = np.diff(self.irregularx) + + def test_incorrect_step_type(self): + with self.assertRaises(TypeError): + utils.logtrapzexp(self.lnfunc1, "blah") + + def test_inconsistent_step_length(self): + with self.assertRaises(ValueError): + utils.logtrapzexp(self.lnfunc1, self.x[0 : len(self.x) // 2]) + + def test_integral_func1(self): + res1 = utils.logtrapzexp(self.lnfunc1, self.dx) + res2 = utils.logtrapzexp(self.lnfunc1, self.dxs) + + self.assertTrue(np.abs(res1 - res2) < 1e-12) + self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12) + + def test_integral_func2(self): + res = utils.logtrapzexp(self.lnfunc2, self.dxs) + self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4) + + def test_integral_func1_irregular_steps(self): + res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs) + self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12) + + def test_integral_func2_irregular_steps(self): + res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs) + self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-2) + + if __name__ == "__main__": unittest.main()