diff --git a/bilby/core/sampler/dynesty.py b/bilby/core/sampler/dynesty.py index 79f21fda034c42e2fd2f376dd482a6f92a8d18b2..5c81f45eed6ab050b9b1298a423289f7b82111fb 100644 --- a/bilby/core/sampler/dynesty.py +++ b/bilby/core/sampler/dynesty.py @@ -208,6 +208,11 @@ class Dynesty(NestedSampler): logger.debug("Setting reflective boundary for {}".format(key)) self._reflective.append(ii) + # The periodic kwargs passed into dynesty allows the parameters to + # wander out of the bounds, this includes both periodic and reflective. + # these are then handled in the prior_transform + self.kwargs["periodic"] = sorted(self._periodic + self._reflective) + def run_sampler(self): import dynesty if self.kwargs['live_points'] is None: diff --git a/test/sampler_test.py b/test/sampler_test.py index 9fc9060179f24b2cf2e8b0a54fb9275b3d26edfa..5835f5a154f94a1b4fa9f4708852e194df0d4bca 100644 --- a/test/sampler_test.py +++ b/test/sampler_test.py @@ -132,8 +132,8 @@ class TestDynesty(unittest.TestCase): def setUp(self): self.likelihood = MagicMock() self.priors = bilby.core.prior.PriorDict() - self.priors['a'] = bilby.core.prior.Prior(boundary='periodic') - self.priors['b'] = bilby.core.prior.Prior(boundary='reflective') + self.priors['a'] = bilby.core.prior.Prior() + self.priors['b'] = bilby.core.prior.Prior() self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors, outdir='outdir', label='label', use_ratio=False, plot=False, @@ -155,15 +155,15 @@ class TestDynesty(unittest.TestCase): logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False, walks=20, update_interval=600, print_func='func') self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise - self.assertListEqual([0, 1], self.sampler.kwargs['periodic']) # Check this separately - self.sampler.kwargs['periodic'] = None # The dict comparison can't handle lists + self.assertEqual([], self.sampler.kwargs['periodic']) # Check this separately + self.sampler.kwargs['periodic'] = expected['periodic'] # The dict comparison can't handle lists for key in self.sampler.kwargs.keys(): print("key={}, expected={}, actual={}" .format(key, expected[key], self.sampler.kwargs[key])) self.assertDictEqual(expected, self.sampler.kwargs) def test_translate_kwargs(self): - expected = dict(bound='multi', sample='rwalk', periodic=[0, 1], verbose=True, + expected = dict(bound='multi', sample='rwalk', periodic=[], verbose=True, check_point_delta_t=600, nlive=1000, first_update=None, npdim=None, rstate=None, queue_size=None, pool=None, use_pool=None, live_points=None, logl_args=None, logl_kwargs=None, @@ -181,6 +181,20 @@ class TestDynesty(unittest.TestCase): self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise self.assertDictEqual(expected, self.sampler.kwargs) + def test_prior_boundary(self): + self.priors['a'] = bilby.core.prior.Prior(boundary='periodic') + self.priors['b'] = bilby.core.prior.Prior(boundary='reflective') + self.priors['c'] = bilby.core.prior.Prior(boundary=None) + self.priors['d'] = bilby.core.prior.Prior(boundary='reflective') + self.priors['e'] = bilby.core.prior.Prior(boundary='periodic') + self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors, + outdir='outdir', label='label', + use_ratio=False, plot=False, + skip_import_verification=True) + self.assertEqual([0, 1, 3, 4], self.sampler.kwargs["periodic"]) + self.assertEqual([0, 4], self.sampler._periodic) + self.assertEqual([1, 3], self.sampler._reflective) + class TestEmcee(unittest.TestCase):