Skip to content
Snippets Groups Projects
Commit 1e54ec66 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'fix_logtrapzexp' into 'master'

Allow logtrapzexp to work on a non-uniform grid

See merge request lscsoft/bilby!983
parents 17e6a80b c96cd97a
No related branches found
No related tags found
No related merge requests found
......@@ -191,8 +191,10 @@ class Grid(object):
places = self.sample_points[name]
if len(places) > 1:
dx = np.diff(places)
out = np.apply_along_axis(
logtrapzexp, axis, log_array, places[1] - places[0])
logtrapzexp, axis, log_array, dx
)
else:
# no marginalisation required, just remove the singleton dimension
z = log_array.shape
......
......@@ -6,8 +6,16 @@ from scipy.special import logsumexp
from .logger import logger
def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
epsscale=0.5, nonfixedidx=None):
def derivatives(
vals,
func,
releps=1e-3,
abseps=None,
mineps=1e-9,
reltol=1e-3,
epsscale=0.5,
nonfixedidx=None,
):
"""
Calculate the partial derivatives of a function at a set of values. The
derivatives are calculated using the central difference, using an iterative
......@@ -54,19 +62,19 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
grads = np.zeros(len(nonfixedidx))
# maximum number of times the gradient can change sign
flipflopmax = 10.
flipflopmax = 10.0
# set steps
if abseps is None:
if isinstance(releps, float):
eps = np.abs(vals) * releps
eps[eps == 0.] = releps # if any values are zero set eps to releps
eps[eps == 0.0] = releps # if any values are zero set eps to releps
teps = releps * np.ones(len(vals))
elif isinstance(releps, (list, np.ndarray)):
if len(releps) != len(vals):
raise ValueError("Problem with input relative step sizes")
eps = np.multiply(np.abs(vals), releps)
eps[eps == 0.] = np.array(releps)[eps == 0.]
eps[eps == 0.0] = np.array(releps)[eps == 0.0]
teps = releps
else:
raise RuntimeError("Relative step sizes are not a recognised type!")
......@@ -107,8 +115,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
cureps *= epsscale
if cureps < mineps or flipflop > flipflopmax:
# if no convergence set flat derivative (TODO: check if there is a better thing to do instead)
logger.warning("Derivative calculation did not converge: setting flat derivative.")
grads[count] = 0.
logger.warning(
"Derivative calculation did not converge: setting flat derivative."
)
grads[count] = 0.0
break
leps *= epsscale
......@@ -122,10 +132,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
break
# check whether previous diff and current diff are the same within reltol
rat = (cdiff / cdiffnew)
if np.isfinite(rat) and rat > 0.:
rat = cdiff / cdiffnew
if np.isfinite(rat) and rat > 0.0:
# gradient has not changed sign
if np.abs(1. - rat) < reltol:
if np.abs(1.0 - rat) < reltol:
grads[count] = cdiffnew
break
else:
......@@ -143,7 +153,7 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
def logtrapzexp(lnf, dx):
"""
Perform trapezium rule integration for the logarithm of a function on a regular grid.
Perform trapezium rule integration for the logarithm of a function on a grid.
Parameters
==========
......@@ -157,13 +167,30 @@ def logtrapzexp(lnf, dx):
=======
The natural logarithm of the area under the function.
"""
return np.log(dx / 2.) + logsumexp([logsumexp(lnf[:-1]), logsumexp(lnf[1:])])
lnfdx1 = lnf[:-1]
lnfdx2 = lnf[1:]
if isinstance(dx, (int, float)):
C = np.log(dx / 2.0)
elif isinstance(dx, (list, np.ndarray)):
if len(dx) != len(lnf) - 1:
raise ValueError(
"Step size array must have length one less than the function length"
)
lndx = np.log(dx)
lnfdx1 = lnfdx1.copy() + lndx
lnfdx2 = lnfdx2.copy() + lndx
C = -np.log(2.0)
else:
raise TypeError("Step size must be a single value or array-like")
return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)])
class UnsortedInterp2d(interp2d):
class UnsortedInterp2d(interp2d):
def __call__(self, x, y, dx=0, dy=0, assume_sorted=False):
""" Modified version of the interp2d call method.
"""Modified version of the interp2d call method.
This avoids the outer product that is done when two numpy
arrays are passed.
......@@ -184,6 +211,7 @@ class UnsortedInterp2d(interp2d):
"""
from scipy.interpolate.dfitpack import bispeu
x, y = self._sanitize_inputs(x, y)
out_of_bounds_x = (x < self.x_min) | (x > self.x_max)
out_of_bounds_y = (y < self.y_min) | (y > self.y_max)
......@@ -216,9 +244,7 @@ class UnsortedInterp2d(interp2d):
y = float(y)
if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
if x.shape != y.shape:
raise ValueError(
"UnsortedInterp2d received unequally shaped arrays"
)
raise ValueError("UnsortedInterp2d received unequally shaped arrays")
elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray):
y = y * np.ones_like(x)
elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
......
......@@ -270,6 +270,7 @@ class TestLatexPlotFormat(unittest.TestCase):
fig, ax = plt.subplots()
ax.plot(self.x, self.y)
fig.savefig(self.filename)
plot()
self.assertTrue(os.path.isfile(self.filename))
......@@ -279,6 +280,7 @@ class TestLatexPlotFormat(unittest.TestCase):
fig, ax = plt.subplots()
ax.plot(self.x, self.y)
fig.savefig(self.filename)
plot(BILBY_MATHDEFAULT=1)
self.assertTrue(os.path.isfile(self.filename))
......@@ -288,6 +290,7 @@ class TestLatexPlotFormat(unittest.TestCase):
fig, ax = plt.subplots()
ax.plot(self.x, self.y)
fig.savefig(self.filename)
plot(BILBY_MATHDEFAULT=0)
self.assertTrue(os.path.isfile(self.filename))
......@@ -313,14 +316,11 @@ class TestLatexPlotFormat(unittest.TestCase):
class TestUnsortedInterp2d(unittest.TestCase):
def setUp(self):
self.xx = np.linspace(0, 1, 10)
self.yy = np.linspace(0, 1, 10)
self.zz = np.random.random((10, 10))
self.interpolant = bilby.core.utils.UnsortedInterp2d(
self.xx, self.yy, self.zz
)
self.interpolant = bilby.core.utils.UnsortedInterp2d(self.xx, self.yy, self.zz)
def tearDown(self):
pass
......@@ -333,22 +333,15 @@ class TestUnsortedInterp2d(unittest.TestCase):
self.assertIsNone(self.interpolant(-0.5, 0.5))
def test_returns_float_for_float_and_array(self):
self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray)
self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray)
self.assertIsInstance(
self.interpolant(0.5, np.random.random(10)), np.ndarray
)
self.assertIsInstance(
self.interpolant(np.random.random(10), 0.5), np.ndarray
)
self.assertIsInstance(
self.interpolant(np.random.random(10), np.random.random(10)),
np.ndarray
self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray
)
def test_raises_for_mismatched_arrays(self):
with self.assertRaises(ValueError):
self.interpolant(
np.random.random(10), np.random.random(20)
)
self.interpolant(np.random.random(10), np.random.random(20))
def test_returns_fill_in_correct_place(self):
x_data = np.random.random(10)
......@@ -357,5 +350,67 @@ class TestUnsortedInterp2d(unittest.TestCase):
self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3]))
class TestTrapeziumRuleIntegration(unittest.TestCase):
def setUp(self):
self.x = np.linspace(0, 1, 100)
self.dxs = np.diff(self.x)
self.dx = self.dxs[0]
with np.errstate(divide="ignore"):
self.lnfunc1 = np.log(self.x)
self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2
with np.errstate(divide="ignore"):
self.lnfunc2 = np.log(self.x ** 2)
self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3
self.irregularx = np.array(
[
self.x[0],
self.x[12],
self.x[19],
self.x[33],
self.x[49],
self.x[55],
self.x[59],
self.x[61],
self.x[73],
self.x[89],
self.x[93],
self.x[97],
self.x[-1],
]
)
with np.errstate(divide="ignore"):
self.lnfunc1irregular = np.log(self.irregularx)
self.lnfunc2irregular = np.log(self.irregularx ** 2)
self.irregulardxs = np.diff(self.irregularx)
def test_incorrect_step_type(self):
with self.assertRaises(TypeError):
utils.logtrapzexp(self.lnfunc1, "blah")
def test_inconsistent_step_length(self):
with self.assertRaises(ValueError):
utils.logtrapzexp(self.lnfunc1, self.x[0 : len(self.x) // 2])
def test_integral_func1(self):
res1 = utils.logtrapzexp(self.lnfunc1, self.dx)
res2 = utils.logtrapzexp(self.lnfunc1, self.dxs)
self.assertTrue(np.abs(res1 - res2) < 1e-12)
self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12)
def test_integral_func2(self):
res = utils.logtrapzexp(self.lnfunc2, self.dxs)
self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4)
def test_integral_func1_irregular_steps(self):
res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs)
self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12)
def test_integral_func2_irregular_steps(self):
res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs)
self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-2)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment