diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index efb25f260dcfe3f3e6ed7a0f5b8a8eea40d9f7e5..4ec703290f93e530141784ff895180cdd3f1cd74 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -10,7 +10,7 @@ import matplotlib.pyplot as plt import numpy as np from pandas import DataFrame -from ..utils import logger, check_directory_exists_and_if_not_mkdir +from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect from .base_sampler import Sampler, NestedSampler @@ -216,6 +216,7 @@ class Dynesty(NestedSampler): def run_sampler(self): import dynesty + logger.info("Using dynesty version {}".format(dynesty.__version__)) if self.kwargs['live_points'] is None: self.kwargs['live_points'] = ( self.get_initial_points_from_prior( @@ -509,4 +510,5 @@ class Dynesty(NestedSampler): |theta| - 1 (i.e. wrap around). """ + theta[self._reflective] = reflect(theta[self._reflective]) return self.priors.rescale(self._search_parameter_keys, theta) diff --git a/bilby/core/utils.py b/bilby/core/utils.py index a406c5c6818cb7edaf7f4d3a445a619c732b61be..9fa3f31775ee2f3f23db5245982463397d8e9884 100644 --- a/bilby/core/utils.py +++ b/bilby/core/utils.py @@ -1004,6 +1004,33 @@ def decode_astropy_quantity(dct): return dct +def reflect(u): + """ + Iteratively reflect a number until it is contained in [0, 1]. + + This is for priors with a reflective boundary condition, all numbers in the set `u = 2n +/- x` should be mapped to x. + + For the `+` case we just take `u % 1`. + For the `-` case we take `1 - (u % 1)`. + + E.g., -0.9, 1.1, and 2.9 should all map to 0.9. + + Parameters + ---------- + u: array-like + The array of points to map to the unit cube + + Returns + ------- + u: array-like + The input array, modified in place. + """ + idxs_even = np.mod(u, 2) < 1 + u[idxs_even] = np.mod(u[idxs_even], 1) + u[~idxs_even] = 1 - np.mod(u[~idxs_even], 1) + return u + + class IllegalDurationAndSamplingFrequencyException(Exception): pass diff --git a/test/utils_test.py b/test/utils_test.py index 5ac7e1b9638e60fbddbf392e36330f03e578e685..a99a3c2705ed448e0c3a222943f7a7bf9383011b 100644 --- a/test/utils_test.py +++ b/test/utils_test.py @@ -165,5 +165,37 @@ class TestTimeAndFrequencyArrays(unittest.TestCase): starting_time=0) +class TestReflect(unittest.TestCase): + + def test_in_range(self): + xprime = np.array([0.1, 0.5, 0.9]) + x = np.array([0.1, 0.5, 0.9]) + self.assertTrue( + np.testing.assert_allclose(utils.reflect(xprime), x) is None) + + def test_in_one_to_two(self): + xprime = np.array([1.1, 1.5, 1.9]) + x = np.array([0.9, 0.5, 0.1]) + self.assertTrue( + np.testing.assert_allclose(utils.reflect(xprime), x) is None) + + def test_in_two_to_three(self): + xprime = np.array([2.1, 2.5, 2.9]) + x = np.array([0.1, 0.5, 0.9]) + self.assertTrue( + np.testing.assert_allclose(utils.reflect(xprime), x) is None) + + def test_in_minus_one_to_zero(self): + xprime = np.array([-0.9, -0.5, -0.1]) + x = np.array([0.9, 0.5, 0.1]) + self.assertTrue( + np.testing.assert_allclose(utils.reflect(xprime), x) is None) + + def test_in_minus_two_to_minus_one(self): + xprime = np.array([-1.9, -1.5, -1.1]) + x = np.array([0.1, 0.5, 0.9]) + self.assertTrue( + np.testing.assert_allclose(utils.reflect(xprime), x) is None) + if __name__ == '__main__': unittest.main()