Skip to content
Snippets Groups Projects
Commit e2e043f1 authored by Gregory Ashton's avatar Gregory Ashton
Browse files

Adding a custom dynesty sampler

parent ea21ab50
No related branches found
Tags 0.6.0
No related merge requests found
......@@ -14,6 +14,11 @@ from pandas import DataFrame
from ..utils import logger, check_directory_exists_and_if_not_mkdir, reflect
from .base_sampler import Sampler, NestedSampler
from numpy import linalg
from dynesty.utils import unitcheck
import warnings
import math
class Dynesty(NestedSampler):
"""
......@@ -225,6 +230,10 @@ class Dynesty(NestedSampler):
self.kwargs['live_points'] = (
self.get_initial_points_from_prior(
self.kwargs['nlive']))
dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_bilby
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_bilby
self.sampler = dynesty.NestedSampler(
loglikelihood=self.log_likelihood,
prior_transform=self.prior_transform,
......@@ -510,15 +519,110 @@ class Dynesty(NestedSampler):
-------
list: Properly rescaled sampled values
Notes
-----
Since dynesty allows periodic parameters to wander outside the unit,
here we transform them depending of if they should be periodic or
reflective. For reflective boundaries, theta < 0 you shift to |theta|
and when theta > 1 you return 2 - theta. For periodic boundaries,
if theta < 0, you shift to 1-|theta| and when theta > 1 you shift to
|theta| - 1 (i.e. wrap around).
"""
theta[self._reflective] = reflect(theta[self._reflective])
return self.priors.rescale(self._search_parameter_keys, theta)
def sample_rwalk_bilby(args):
"""
Modified version of dynesty.sampling.sample_rwalk
"""
# Unzipping.
(u, loglstar, axes, scale,
prior_transform, loglikelihood, kwargs) = args
rstate = np.random
# Bounds
nonbounded = kwargs.get('nonbounded', None)
periodic = kwargs.get('periodic', None)
reflective = kwargs.get('reflective', None)
# Setup.
n = len(u)
walks = kwargs.get('walks', 25) # number of steps
accept = 0
reject = 0
fail = 0
nfail = 0
nc = 0
ncall = 0
drhat, dr, du, u_prop, logl_prop = np.nan, np.nan, np.nan, np.nan, np.nan
while nc + nfail < walks or accept == 0:
while True:
# Check scale-factor.
if scale == 0.:
raise RuntimeError("The random walk sampling is stuck! "
"Some useful output quantities:\n"
"u: {0}\n"
"drhat: {1}\n"
"dr: {2}\n"
"du: {3}\n"
"u_prop: {4}\n"
"loglstar: {5}\n"
"logl_prop: {6}\n"
"axes: {7}\n"
"scale: {8}."
.format(u, drhat, dr, du, u_prop,
loglstar, logl_prop, axes, scale))
# Propose a direction on the unit n-sphere.
drhat = rstate.randn(n)
drhat /= linalg.norm(drhat)
# Scale based on dimensionality.
dr = drhat * rstate.rand()**(1. / n)
# Transform to proposal distribution.
du = np.dot(axes, dr)
u_prop = u + scale * du
# Wrap periodic parameters
if periodic is not None:
u_prop[periodic] = np.mod(u_prop[periodic], 1)
# Reflect
if reflective is not None:
u_prop[reflective] = reflect(u_prop[reflective])
# Check unit cube constraints.
if unitcheck(u_prop, nonbounded):
break
else:
fail += 1
nfail += 1
# Check if we're stuck generating bad numbers.
if fail > 100 * walks:
warnings.warn("Random number generation appears to be "
"extremely inefficient. Adjusting the "
"scale-factor accordingly.")
fail = 0
scale *= math.exp(-1. / n)
# Check proposed point.
v_prop = prior_transform(np.array(u_prop))
logl_prop = loglikelihood(np.array(v_prop))
if logl_prop >= loglstar:
u = u_prop
v = v_prop
logl = logl_prop
accept += 1
else:
reject += 1
nc += 1
ncall += 1
# Check if we're stuck generating bad points.
if nc > 50 * walks:
scale *= math.exp(-1. / n)
warnings.warn("Random walk proposals appear to be "
"extremely inefficient. Adjusting the "
"scale-factor accordingly.")
nc, accept, reject = 0, 0, 0 # reset values
blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
return u, v, logl, ncall, blob
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment