diff --git a/test/core/utils_test.py b/test/core/utils_test.py index 774311bda5559f715d5a73f1d181f19869db7325..1d6b87306fa110422e33a41a2e575dea5b0bff84 100644 --- a/test/core/utils_test.py +++ b/test/core/utils_test.py @@ -357,5 +357,67 @@ class TestUnsortedInterp2d(unittest.TestCase): self.assertTrue(np.isnan(self.interpolant(x_data, y_data)[3])) +class TestTrapeziumRuleIntegration(unittest.TestCase): + def setUp(self): + self.x = np.linspace(0, 1, 100) + self.dxs = np.diff(self.x) + self.dx = self.dxs[0] + with np.errstate(divide="ignore"): + self.lnfunc1 = np.log(self.x) + self.func1int = (self.x[-1] ** 2 - self.x[0] ** 2) / 2 + with np.errstate(divide="ignore"): + self.lnfunc2 = np.log(self.x ** 2) + self.func2int = (self.x[-1] ** 3 - self.x[0] ** 3) / 3 + + self.irregularx = np.array( + [ + self.x[0], + self.x[12], + self.x[19], + self.x[33], + self.x[49], + self.x[55], + self.x[59], + self.x[61], + self.x[73], + self.x[89], + self.x[93], + self.x[97], + self.x[-1], + ] + ) + with np.errstate(divide="ignore"): + self.lnfunc1irregular = np.log(self.irregularx) + self.lnfunc2irregular = np.log(self.irregularx ** 2) + self.irregulardxs = np.diff(self.irregularx) + + def test_incorrect_step_type(self): + with self.assertRaises(TypeError): + utils.logtrapzexp(self.lnfunc1, "blah") + + def test_inconsistent_step_length(self): + with self.assertRaises(ValueError): + utils.logtrapzexp(self.lnfunc1, self.x[0:len(self.x) // 2]) + + def test_integral_func1(self): + res1 = utils.logtrapzexp(self.lnfunc1, self.dx) + res2 = utils.logtrapzexp(self.lnfunc1, self.dxs) + + self.assertTrue(np.abs(res1 - res2) < 1e-12) + self.assertTrue(np.abs((np.exp(res1) - self.func1int) / self.func1int) < 1e-12) + + def test_integral_func2(self): + res = utils.logtrapzexp(self.lnfunc2, self.dxs) + self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-4) + + def test_integral_func1_irregular_steps(self): + res = utils.logtrapzexp(self.lnfunc1irregular, self.irregulardxs) + self.assertTrue(np.abs((np.exp(res) - self.func1int) / self.func1int) < 1e-12) + + def test_integral_func2_irregular_steps(self): + res = utils.logtrapzexp(self.lnfunc2irregular, self.irregulardxs) + self.assertTrue(np.abs((np.exp(res) - self.func2int) / self.func2int) < 1e-2) + + if __name__ == "__main__": unittest.main()