Skip to content
Snippets Groups Projects
Commit 3cc57b54 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Pass the periodic list through to dynesty

- Ensures the periodic argument is passed through to the dynesty making
  the points wander outside the allowed range
- Fix tests
parent a893d62e
No related branches found
No related tags found
1 merge request!589Implement dynesty none/reflective/periodic boundaries
Pipeline #78616 passed
......@@ -208,6 +208,11 @@ class Dynesty(NestedSampler):
logger.debug("Setting reflective boundary for {}".format(key))
self._reflective.append(ii)
# The periodic kwargs passed into dynesty allows the parameters to
# wander out of the bounds, this includes both periodic and reflective.
# these are then handled in the prior_transform
self.kwargs["periodic"] = sorted(self._periodic + self._reflective)
def run_sampler(self):
import dynesty
if self.kwargs['live_points'] is None:
......
......@@ -132,8 +132,8 @@ class TestDynesty(unittest.TestCase):
def setUp(self):
self.likelihood = MagicMock()
self.priors = bilby.core.prior.PriorDict()
self.priors['a'] = bilby.core.prior.Prior(boundary='periodic')
self.priors['b'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['a'] = bilby.core.prior.Prior()
self.priors['b'] = bilby.core.prior.Prior()
self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors,
outdir='outdir', label='label',
use_ratio=False, plot=False,
......@@ -155,15 +155,15 @@ class TestDynesty(unittest.TestCase):
logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
walks=20, update_interval=600, print_func='func')
self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise
self.assertListEqual([0, 1], self.sampler.kwargs['periodic']) # Check this separately
self.sampler.kwargs['periodic'] = None # The dict comparison can't handle lists
self.assertEqual([], self.sampler.kwargs['periodic']) # Check this separately
self.sampler.kwargs['periodic'] = expected['periodic'] # The dict comparison can't handle lists
for key in self.sampler.kwargs.keys():
print("key={}, expected={}, actual={}"
.format(key, expected[key], self.sampler.kwargs[key]))
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
expected = dict(bound='multi', sample='rwalk', periodic=[0, 1], verbose=True,
expected = dict(bound='multi', sample='rwalk', periodic=[], verbose=True,
check_point_delta_t=600, nlive=1000, first_update=None,
npdim=None, rstate=None, queue_size=None, pool=None,
use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
......@@ -181,6 +181,20 @@ class TestDynesty(unittest.TestCase):
self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise
self.assertDictEqual(expected, self.sampler.kwargs)
def test_prior_boundary(self):
self.priors['a'] = bilby.core.prior.Prior(boundary='periodic')
self.priors['b'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['c'] = bilby.core.prior.Prior(boundary=None)
self.priors['d'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['e'] = bilby.core.prior.Prior(boundary='periodic')
self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors,
outdir='outdir', label='label',
use_ratio=False, plot=False,
skip_import_verification=True)
self.assertEqual([0, 1, 3, 4], self.sampler.kwargs["periodic"])
self.assertEqual([0, 4], self.sampler._periodic)
self.assertEqual([1, 3], self.sampler._reflective)
class TestEmcee(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment