Skip to content
Snippets Groups Projects
Commit 6ad057be authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Merge branch 'make-reflective-boundaries-optional-in-dynesty'

parents fa9e0e2b 9d8dd886
No related branches found
No related tags found
1 merge request!589Implement dynesty none/reflective/periodic boundaries
Pipeline #80926 failed
......@@ -75,7 +75,7 @@ class Dynesty(NestedSampler):
If true, resume run from checkpoint (if available)
"""
default_kwargs = dict(bound='multi', sample='rwalk',
verbose=True, periodic=None,
verbose=True, periodic=None, reflective=None,
check_point_delta_t=600, nlive=1000,
first_update=None, walks=None,
npdim=None, rstate=None, queue_size=None, pool=None,
......@@ -198,19 +198,21 @@ class Dynesty(NestedSampler):
sys.stdout.flush()
def _apply_dynesty_boundaries(self):
if self.kwargs['periodic'] is None:
logger.debug("Setting periodic boundaries for keys:")
self.kwargs['periodic'] = []
self._periodic = list()
self._reflective = list()
for ii, key in enumerate(self.search_parameter_keys):
if self.priors[key].boundary in ['periodic', 'reflective']:
self.kwargs['periodic'].append(ii)
logger.debug(" {}".format(key))
if self.priors[key].boundary == 'periodic':
self._periodic.append(ii)
else:
self._reflective.append(ii)
self._periodic = list()
self._reflective = list()
for ii, key in enumerate(self.search_parameter_keys):
if self.priors[key].boundary == 'periodic':
logger.debug("Setting periodic boundary for {}".format(key))
self._periodic.append(ii)
elif self.priors[key].boundary == 'reflective':
logger.debug("Setting reflective boundary for {}".format(key))
self._reflective.append(ii)
# The periodic kwargs passed into dynesty allows the parameters to
# wander out of the bounds, this includes both periodic and reflective.
# these are then handled in the prior_transform
self.kwargs["periodic"] = self._periodic
self.kwargs["reflective"] = self._reflective
def run_sampler(self):
import dynesty
......@@ -505,8 +507,4 @@ class Dynesty(NestedSampler):
|theta| - 1 (i.e. wrap around).
"""
theta[self._periodic] = np.mod(theta[self._periodic], 1)
theta_ref = theta[self._reflective]
theta[self._reflective] = np.minimum(
np.maximum(theta_ref, abs(theta_ref)), 2 - theta_ref)
return self.priors.rescale(self._search_parameter_keys, theta)
......@@ -78,7 +78,7 @@ setup(name='bilby',
'bilby': [version_file]},
install_requires=[
'future',
'dynesty>=0.9.7',
'dynesty>=1.0.0',
'corner',
'dill',
'numpy>=1.9',
......
......@@ -134,8 +134,8 @@ class TestDynesty(unittest.TestCase):
def setUp(self):
self.likelihood = MagicMock()
self.priors = bilby.core.prior.PriorDict()
self.priors['a'] = bilby.core.prior.Prior(boundary='periodic')
self.priors['b'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['a'] = bilby.core.prior.Prior()
self.priors['b'] = bilby.core.prior.Prior()
self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors,
outdir='outdir', label='label',
use_ratio=False, plot=False,
......@@ -147,7 +147,7 @@ class TestDynesty(unittest.TestCase):
del self.sampler
def test_default_kwargs(self):
expected = dict(bound='multi', sample='rwalk', periodic=None, verbose=True,
expected = dict(bound='multi', sample='rwalk', periodic=None, reflective=None, verbose=True,
check_point_delta_t=600, nlive=1000, first_update=None,
npdim=None, rstate=None, queue_size=None, pool=None,
use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
......@@ -157,15 +157,18 @@ class TestDynesty(unittest.TestCase):
logl_max=np.inf, add_live=True, print_progress=True, save_bounds=False,
walks=20, update_interval=600, print_func='func', n_effective=None)
self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise
self.assertListEqual([0, 1], self.sampler.kwargs['periodic']) # Check this separately
self.sampler.kwargs['periodic'] = None # The dict comparison can't handle lists
# DictEqual can't handle lists so we check these separately
self.assertEqual([], self.sampler.kwargs['periodic'])
self.assertEqual([], self.sampler.kwargs['reflective'])
self.sampler.kwargs['periodic'] = expected['periodic']
self.sampler.kwargs['reflective'] = expected['reflective']
for key in self.sampler.kwargs.keys():
print("key={}, expected={}, actual={}"
.format(key, expected[key], self.sampler.kwargs[key]))
self.assertDictEqual(expected, self.sampler.kwargs)
def test_translate_kwargs(self):
expected = dict(bound='multi', sample='rwalk', periodic=[0, 1], verbose=True,
expected = dict(bound='multi', sample='rwalk', periodic=[], reflective=[], verbose=True,
check_point_delta_t=600, nlive=1000, first_update=None,
npdim=None, rstate=None, queue_size=None, pool=None,
use_pool=None, live_points=None, logl_args=None, logl_kwargs=None,
......@@ -183,6 +186,21 @@ class TestDynesty(unittest.TestCase):
self.sampler.kwargs['print_func'] = 'func' # set this manually as this is not testable otherwise
self.assertDictEqual(expected, self.sampler.kwargs)
def test_prior_boundary(self):
self.priors['a'] = bilby.core.prior.Prior(boundary='periodic')
self.priors['b'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['c'] = bilby.core.prior.Prior(boundary=None)
self.priors['d'] = bilby.core.prior.Prior(boundary='reflective')
self.priors['e'] = bilby.core.prior.Prior(boundary='periodic')
self.sampler = bilby.core.sampler.Dynesty(self.likelihood, self.priors,
outdir='outdir', label='label',
use_ratio=False, plot=False,
skip_import_verification=True)
self.assertEqual([0, 4], self.sampler.kwargs["periodic"])
self.assertEqual(self.sampler._periodic, self.sampler.kwargs["periodic"])
self.assertEqual([1, 3], self.sampler.kwargs["reflective"])
self.assertEqual(self.sampler._reflective, self.sampler.kwargs["reflective"])
class TestEmcee(unittest.TestCase):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment