diff --git a/bilby/core/grid.py b/bilby/core/grid.py
index 8562e6d67b3afbaa863e85fe8125e62248a7bc41..8e264872d84fa2ef1daa9fc9a8b0d51c9de398a0 100644
--- a/bilby/core/grid.py
+++ b/bilby/core/grid.py
@@ -191,8 +191,10 @@ class Grid(object):
         places = self.sample_points[name]
 
         if len(places) > 1:
+            dx = np.diff(places)
             out = np.apply_along_axis(
-                logtrapzexp, axis, log_array, places[1] - places[0])
+                logtrapzexp, axis, log_array, dx
+            )
         else:
             # no marginalisation required, just remove the singleton dimension
             z = log_array.shape
diff --git a/bilby/core/utils/calculus.py b/bilby/core/utils/calculus.py
index fbd64f9f012e8983a20873f88c50db1953a0973c..ef7af61dec6a139e662cf53438cd07e2aaba6255 100644
--- a/bilby/core/utils/calculus.py
+++ b/bilby/core/utils/calculus.py
@@ -6,8 +6,16 @@ from scipy.special import logsumexp
 from .logger import logger
 
 
-def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
-                epsscale=0.5, nonfixedidx=None):
+def derivatives(
+    vals,
+    func,
+    releps=1e-3,
+    abseps=None,
+    mineps=1e-9,
+    reltol=1e-3,
+    epsscale=0.5,
+    nonfixedidx=None,
+):
     """
     Calculate the partial derivatives of a function at a set of values. The
     derivatives are calculated using the central difference, using an iterative
@@ -54,19 +62,19 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
     grads = np.zeros(len(nonfixedidx))
 
     # maximum number of times the gradient can change sign
-    flipflopmax = 10.
+    flipflopmax = 10.0
 
     # set steps
     if abseps is None:
         if isinstance(releps, float):
             eps = np.abs(vals) * releps
-            eps[eps == 0.] = releps  # if any values are zero set eps to releps
+            eps[eps == 0.0] = releps  # if any values are zero set eps to releps
             teps = releps * np.ones(len(vals))
         elif isinstance(releps, (list, np.ndarray)):
             if len(releps) != len(vals):
                 raise ValueError("Problem with input relative step sizes")
             eps = np.multiply(np.abs(vals), releps)
-            eps[eps == 0.] = np.array(releps)[eps == 0.]
+            eps[eps == 0.0] = np.array(releps)[eps == 0.0]
             teps = releps
         else:
             raise RuntimeError("Relative step sizes are not a recognised type!")
@@ -107,8 +115,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
             cureps *= epsscale
             if cureps < mineps or flipflop > flipflopmax:
                 # if no convergence set flat derivative (TODO: check if there is a better thing to do instead)
-                logger.warning("Derivative calculation did not converge: setting flat derivative.")
-                grads[count] = 0.
+                logger.warning(
+                    "Derivative calculation did not converge: setting flat derivative."
+                )
+                grads[count] = 0.0
                 break
             leps *= epsscale
 
@@ -122,10 +132,10 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
                 break
 
             # check whether previous diff and current diff are the same within reltol
-            rat = (cdiff / cdiffnew)
-            if np.isfinite(rat) and rat > 0.:
+            rat = cdiff / cdiffnew
+            if np.isfinite(rat) and rat > 0.0:
                 # gradient has not changed sign
-                if np.abs(1. - rat) < reltol:
+                if np.abs(1.0 - rat) < reltol:
                     grads[count] = cdiffnew
                     break
                 else:
@@ -143,7 +153,7 @@ def derivatives(vals, func, releps=1e-3, abseps=None, mineps=1e-9, reltol=1e-3,
 
 def logtrapzexp(lnf, dx):
     """
-    Perform trapezium rule integration for the logarithm of a function on a regular grid.
+    Perform trapezium rule integration for the logarithm of a function on a grid.
 
     Parameters
     ==========
@@ -157,13 +167,30 @@ def logtrapzexp(lnf, dx):
     =======
     The natural logarithm of the area under the function.
     """
-    return np.log(dx / 2.) + logsumexp([logsumexp(lnf[:-1]), logsumexp(lnf[1:])])
 
+    lnfdx1 = lnf[:-1]
+    lnfdx2 = lnf[1:]
+    if isinstance(dx, (int, float, np.number)):
+        # np.number also admits NumPy scalars (e.g. np.float32, np.int64)
+        C = np.log(dx / 2.0)
+    elif isinstance(dx, (list, tuple, np.ndarray)):
+        if len(dx) != len(lnf) - 1:
+            raise ValueError(
+                "Step size array must have length one less than the function length"
+            )
+        # fold each (possibly unequal) step width into both trapezium edges
+        lndx = np.log(dx)
+        lnfdx1 = lnfdx1 + lndx
+        lnfdx2 = lnfdx2 + lndx
+        C = -np.log(2.0)
+    else:
+        raise TypeError("Step size must be a single value or array-like")
+    return C + logsumexp([logsumexp(lnfdx1), logsumexp(lnfdx2)])
 
-class UnsortedInterp2d(interp2d):
 
+class UnsortedInterp2d(interp2d):
     def __call__(self, x, y, dx=0, dy=0, assume_sorted=False):
-        """  Modified version of the interp2d call method.
+        """Modified version of the interp2d call method.
 
         This avoids the outer product that is done when two numpy
         arrays are passed.
@@ -184,6 +211,7 @@ class UnsortedInterp2d(interp2d):
 
         """
         from scipy.interpolate.dfitpack import bispeu
+
         x, y = self._sanitize_inputs(x, y)
         out_of_bounds_x = (x < self.x_min) | (x > self.x_max)
         out_of_bounds_y = (y < self.y_min) | (y > self.y_max)
@@ -216,9 +244,7 @@ class UnsortedInterp2d(interp2d):
             y = float(y)
         if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
             if x.shape != y.shape:
-                raise ValueError(
-                    "UnsortedInterp2d received unequally shaped arrays"
-                )
+                raise ValueError("UnsortedInterp2d received unequally shaped arrays")
         elif isinstance(x, np.ndarray) and not isinstance(y, np.ndarray):
             y = y * np.ones_like(x)
         elif not isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
diff --git a/test/core/utils_test.py b/test/core/utils_test.py
index 774311bda5559f715d5a73f1d181f19869db7325..d6b3e8902dac7c0ca77275a37a43e24b8e43d4e8 100644
--- a/test/core/utils_test.py
+++ b/test/core/utils_test.py
@@ -270,6 +270,7 @@ class TestLatexPlotFormat(unittest.TestCase):
             fig, ax = plt.subplots()
             ax.plot(self.x, self.y)
             fig.savefig(self.filename)
+
         plot()
         self.assertTrue(os.path.isfile(self.filename))
 
@@ -279,6 +280,7 @@ class TestLatexPlotFormat(unittest.TestCase):
             fig, ax = plt.subplots()
             ax.plot(self.x, self.y)
             fig.savefig(self.filename)
+
         plot(BILBY_MATHDEFAULT=1)
         self.assertTrue(os.path.isfile(self.filename))
 
@@ -288,6 +290,7 @@ class TestLatexPlotFormat(unittest.TestCase):
             fig, ax = plt.subplots()
             ax.plot(self.x, self.y)
             fig.savefig(self.filename)
+
         plot(BILBY_MATHDEFAULT=0)
         self.assertTrue(os.path.isfile(self.filename))
 
@@ -313,14 +316,11 @@ class TestLatexPlotFormat(unittest.TestCase):
 
 
 class TestUnsortedInterp2d(unittest.TestCase):
-
     def setUp(self):
         self.xx = np.linspace(0, 1, 10)
         self.yy = np.linspace(0, 1, 10)
         self.zz = np.random.random((10, 10))
-        self.interpolant = bilby.core.utils.UnsortedInterp2d(
-            self.xx, self.yy, self.zz
-        )
+        self.interpolant = bilby.core.utils.UnsortedInterp2d(self.xx, self.yy, self.zz)
 
     def tearDown(self):
         pass
@@ -333,22 +333,15 @@ class TestUnsortedInterp2d(unittest.TestCase):
         self.assertIsNone(self.interpolant(-0.5, 0.5))
 
     def test_returns_float_for_float_and_array(self):
+        self.assertIsInstance(self.interpolant(0.5, np.random.random(10)), np.ndarray)
+        self.assertIsInstance(self.interpolant(np.random.random(10), 0.5), np.ndarray)
         self.assertIsInstance(
-            self.interpolant(0.5, np.random.random(10)), np.ndarray
-        )
-        self.assertIsInstance(
-            self.interpolant(np.random.random(10), 0.5), np.ndarray
-        )
-        self.assertIsInstance(
-            self.interpolant(np.random.random(10), np.random.random(10)),
-            np.ndarray
+            self.interpolant(np.random.random(10), np.random.random(10)), np.ndarray
         )
 
     def test_raises_for_mismatched_arrays(self):
         with self.assertRaises(ValueError):
-            self.interpolant(
-                np.random.random(10), np.random.random(20)
-            )
+            self.interpolant(np.random.random(10), np.random.random(20))
 
     def test_returns_fill_in_correct_place(self):
         x_data = np.random.random(10)
@@ -357,5 +350,67 @@ class TestUnsortedInterp2d(unittest.TestCase):
         self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3]))
 
 
+class TestTrapeziumRuleIntegration(unittest.TestCase):
+    """Compare ``utils.logtrapzexp`` with analytic integrals of x and x**2."""
+
+    # positions in the regular grid that form the unevenly spaced grid
+    _IRREGULAR_IDX = [0, 12, 19, 33, 49, 55, 59, 61, 73, 89, 93, 97, -1]
+
+    def setUp(self):
+        # regular grid on [0, 1], its step widths and the first step width
+        self.x = np.linspace(0, 1, 100)
+        self.dxs = np.diff(self.x)
+        self.dx = self.dxs[0]
+
+        # unevenly spaced subset that keeps both end points of the grid
+        self.irregularx = self.x[self._IRREGULAR_IDX]
+        self.irregulardxs = np.diff(self.irregularx)
+
+        # both grids start at 0, so log(0) = -inf is expected; silence the warning
+        with np.errstate(divide="ignore"):
+            self.lnfunc1 = np.log(self.x)
+            self.lnfunc2 = np.log(self.x ** 2)
+            self.lnfunc1irregular = np.log(self.irregularx)
+            self.lnfunc2irregular = np.log(self.irregularx ** 2)
+
+        # analytic integrals of x and x**2 between the grid end points
+        self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2
+        self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3
+
+    def _assert_log_integral(self, lnres, expected, tol):
+        """Assert that exp(lnres) equals expected within relative tolerance tol."""
+        self.assertLess(np.abs((np.exp(lnres) - expected) / expected), tol)
+
+    def test_incorrect_step_type(self):
+        # a string step size must be rejected
+        with self.assertRaises(TypeError):
+            utils.logtrapzexp(self.lnfunc1, "blah")
+
+    def test_inconsistent_step_length(self):
+        # 50 step sizes cannot match the 100 function values
+        with self.assertRaises(ValueError):
+            utils.logtrapzexp(self.lnfunc1, self.x[: len(self.x) // 2])
+
+    def test_integral_func1(self):
+        scalar_step = utils.logtrapzexp(self.lnfunc1, self.dx)
+        array_step = utils.logtrapzexp(self.lnfunc1, self.dxs)
+
+        # a single step and per-interval steps must agree on the regular grid
+        self.assertLess(np.abs(scalar_step - array_step), 1e-12)
+        self._assert_log_integral(scalar_step, self.func1int, 1e-12)
+
+    def test_integral_func2(self):
+        lnint = utils.logtrapzexp(self.lnfunc2, self.dxs)
+        self._assert_log_integral(lnint, self.func2int, 1e-4)
+
+    def test_integral_func1_irregular_steps(self):
+        lnint = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs)
+        self._assert_log_integral(lnint, self.func1int, 1e-12)
+
+    def test_integral_func2_irregular_steps(self):
+        lnint = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs)
+        self._assert_log_integral(lnint, self.func2int, 1e-2)
+
+
 if __name__ == "__main__":
     unittest.main()